import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.python.layers import core as layers_core


def getLayeredCell(layer_size, num_units, input_keep_prob,
        output_keep_prob=1.0):
    # stack of LSTM cells, each wrapped with dropout on its inputs
    return rnn.MultiRNNCell([rnn.DropoutWrapper(rnn.BasicLSTMCell(num_units),
        input_keep_prob, output_keep_prob) for i in range(layer_size)])


def bi_encoder(embed_input, in_seq_len, num_units, layer_size,
        input_keep_prob):
    # encode the input into a sequence of vectors with a bidirectional RNN;
    # half of the layers run forward, half run backward
    bi_layer_size = int(layer_size / 2)
    encode_cell_fw = getLayeredCell(bi_layer_size, num_units, input_keep_prob)
    encode_cell_bw = getLayeredCell(bi_layer_size, num_units, input_keep_prob)
    bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=encode_cell_fw,
            cell_bw=encode_cell_bw,
            inputs=embed_input,
            sequence_length=in_seq_len,
            dtype=embed_input.dtype,
            time_major=False)

    # concat the forward and backward outputs, and interleave the forward and
    # backward states so the result looks like a layer_size-deep stack
    encoder_output = tf.concat(bi_encoder_output, -1)
    encoder_state = []
    for layer_id in range(bi_layer_size):
        encoder_state.append(bi_encoder_state[0][layer_id])
        encoder_state.append(bi_encoder_state[1][layer_id])
    encoder_state = tuple(encoder_state)
    return encoder_output, encoder_state


def attention_decoder_cell(encoder_output, in_seq_len, num_units, layer_size,
        input_keep_prob):
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units,
            encoder_output, in_seq_len, normalize=True)
    # attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units,
    #         encoder_output, in_seq_len, scale=True)
    cell = getLayeredCell(layer_size, num_units, input_keep_prob)
    cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism,
            attention_layer_size=num_units)
    return cell


def decoder_projection(output, output_size):
    # project decoder outputs to vocabulary-sized logits
    return tf.layers.dense(output, output_size, activation=None,
            use_bias=False, name='output_mlp')


def train_decoder(encoder_output, in_seq_len, target_seq, target_seq_len,
        encoder_state, num_units, layers, embedding, output_size,
        input_keep_prob, projection_layer):
    decoder_cell = attention_decoder_cell(encoder_output, in_seq_len,
            num_units, layers, input_keep_prob)
    batch_size = tf.shape(in_seq_len)[0]
    init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
            cell_state=encoder_state)
    # teacher forcing: feed the ground-truth target sequence at each step
    helper = tf.contrib.seq2seq.TrainingHelper(
            target_seq, target_seq_len, time_major=False)
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
            init_state, output_layer=projection_layer)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
            maximum_iterations=100)
    return outputs.rnn_output


def infer_decoder(encoder_output, in_seq_len, encoder_state, num_units,
        layers, embedding, output_size, input_keep_prob, projection_layer):
    decoder_cell = attention_decoder_cell(encoder_output, in_seq_len,
            num_units, layers, input_keep_prob)
    batch_size = tf.shape(in_seq_len)[0]
    init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
            cell_state=encoder_state)

    # TODO: start token (0) and end token (1) are hard-coded
    """
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding, tf.fill([batch_size], 0), 1)
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
            init_state, output_layer=projection_layer)
    """
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=embedding,
            start_tokens=tf.fill([batch_size], 0),
            end_token=1,
            initial_state=init_state,
            beam_width=10,
            output_layer=projection_layer,
            length_penalty_weight=1.0)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
            maximum_iterations=100)
    return outputs.sample_id


def seq2seq(in_seq, in_seq_len, target_seq, target_seq_len, vocab_size,
        num_units, layers, dropout):
    batch_size = tf.shape(in_seq_len)[0]
    # only apply dropout while training, i.e. when a target sequence is given
    if target_seq is not None:
        input_keep_prob = 1 - dropout
    else:
        input_keep_prob = 1

    projection_layer = layers_core.Dense(vocab_size, use_bias=False)

    # embed the input and target sequences
    with tf.device('/cpu:0'):
        embedding = tf.get_variable(
                name='embedding',
                shape=[vocab_size, num_units])
    embed_input = tf.nn.embedding_lookup(embedding, in_seq,
            name='embed_input')

    # encode, then decode with attention over the encoder outputs
    encoder_output, encoder_state = bi_encoder(embed_input, in_seq_len,
            num_units, layers, input_keep_prob)
    decoder_cell = attention_decoder_cell(encoder_output, in_seq_len,
            num_units, layers, input_keep_prob)
    init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
            cell_state=encoder_state)

    if target_seq is not None:
        embed_target = tf.nn.embedding_lookup(embedding, target_seq,
                name='embed_target')
        helper = tf.contrib.seq2seq.TrainingHelper(
                embed_target, target_seq_len, time_major=False)
    else:
        # TODO: start token (0) and end token (1) are hard-coded
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding, tf.fill([batch_size], 0), 1)
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
            init_state, output_layer=projection_layer)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
            maximum_iterations=100)
    # return logits for training, decoded token ids for inference
    if target_seq is not None:
        return outputs.rnn_output
    else:
        return outputs.sample_id


def seq_loss(output, target, seq_len):
    # drop the leading start token so targets align with the decoder logits
    target = target[:, 1:]
    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output,
            labels=target)
    batch_size = tf.shape(target)[0]
    # mask out positions beyond each sequence's real length
    loss_mask = tf.sequence_mask(seq_len, tf.shape(output)[1])
    cost = cost * tf.to_float(loss_mask)
    return tf.reduce_sum(cost) / tf.to_float(batch_size)
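

# ---------------------------------------------------------------------------
# Example: wiring seq2seq and seq_loss into a training graph.
# This is a minimal sketch, not part of the model code above. It assumes
# TensorFlow 1.x with tf.contrib, and that id 0 / id 1 are the start / end
# tokens (matching the hard-coded values above). The helper name
# build_train_graph, the hyperparameter defaults, and the gradient-clipping
# norm are illustrative choices; how start/end tokens are framed in the
# target sequences is left to the data pipeline.
# ---------------------------------------------------------------------------
def build_train_graph(vocab_size=10000, num_units=512, layers=4, dropout=0.2,
        learning_rate=0.001):
    # batch-major int32 id sequences plus their true lengths
    in_seq = tf.placeholder(tf.int32, shape=[None, None], name='in_seq')
    in_seq_len = tf.placeholder(tf.int32, shape=[None], name='in_seq_len')
    target_seq = tf.placeholder(tf.int32, shape=[None, None],
            name='target_seq')
    target_seq_len = tf.placeholder(tf.int32, shape=[None],
            name='target_seq_len')

    # passing a target sequence selects the training branch, so seq2seq
    # returns per-step logits of shape [batch, time, vocab_size]
    logits = seq2seq(in_seq, in_seq_len, target_seq, target_seq_len,
            vocab_size, num_units, layers, dropout)
    loss = seq_loss(logits, target_seq, target_seq_len)

    # clip gradients to keep the deep LSTM stack stable, then apply Adam
    params = tf.trainable_variables()
    grads = tf.gradients(loss, params)
    grads, _ = tf.clip_by_global_norm(grads, 5.0)
    train_op = tf.train.AdamOptimizer(learning_rate).apply_gradients(
            zip(grads, params))
    return train_op, loss, (in_seq, in_seq_len, target_seq, target_seq_len)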