encoder
# 首先构造单个rnn cell
encoder_f_cell = LSTMCell(self.hidden_size)
encoder_b_cell = LSTMCell(self.hidden_size)
(encoder_fw_outputs, encoder_bw_outputs),
(encoder_fw_final_state, encoder_bw_final_state) = \
tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_f_cell,
cell_bw=encoder_b_cell,
inputs=self.encoder_inputs_embedded,
sequence_length=self.encoder_inputs_actual_length,
dtype=tf.float32, time_major=True)
time_major:如果是True,则输入需要是T×B×E,T代表时间序列的长度,B代表batch size,E代表词向量的维度。否则,为B×T×E。输出也是类似。
outputs:针对所有时间序列上的输出。
final_state:只是最后一个时间节点的状态。
Helper
常用的Helper:
- TrainingHelper:适用于训练的helper。
- InferenceHelper:适用于测试的helper。
- GreedyEmbeddingHelper:适用于测试中采用Greedy策略sample的helper。
reuse
在训练和推断阶段,模型的decoder部分表现不同,而我们需要两者共享权值,所以需要变量管理.在生成上下文管理器时,将参数reuse设置为True.这样tf.get_variable函数将直接获取已经声明的变量,代码如下:
with tf.variable_scope("decode"):
training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
sequence_length=target_sequence_length,
time_major=False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(cell, training_helper, encoder_state, output_layer)
training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder, impute_finished=True,
maximum_iterations=max_target_sequence_length)
with tf.variable_scope("decode", reuse=True):
start_tokens = tf.tile(tf.constant([target_letter_to_int['<BOS>']], dtype=tf.int32), [self.batch_size],name='start_token')
predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings, start_tokens,
target_letter_to_int['<EOS>'])
predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,predicting_helper,encoder_state,output_layer)
predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,impute_finished=True,
maximum_iterations=max_target_sequence_length)
Loss Function
tf.contrib.seq2seq.sequence_loss可以直接计算序列的损失函数,重要参数:
logits:尺寸[batch_size, sequence_length, num_decoder_symbols]
targets:尺寸[batch_size, sequence_length],不用做one_hot。
weights:[batch_size, sequence_length],即mask,滤去padding的loss计算,使loss计算更准确。
其中num_decoder_symbols指的是taget的词表大小,如果在阅读理解当中,source和target词表通常是一样,在机器翻译中则完全不同.还有就是要注意sequence_length是batch中的长度还是所有数据上的长度,一般来说可以取batch中的长度.