Tensorflow Seq2seq代码总结

1,161 阅读1分钟

encoder

# 首先构造单个rnn cell
encoder_f_cell = LSTMCell(self.hidden_size)
encoder_b_cell = LSTMCell(self.hidden_size)
 (encoder_fw_outputs, encoder_bw_outputs),
 (encoder_fw_final_state, encoder_bw_final_state) = \
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_f_cell,
                                            cell_bw=encoder_b_cell,
                                            inputs=self.encoder_inputs_embedded,
                                            sequence_length=self.encoder_inputs_actual_length,
                                            dtype=tf.float32, time_major=True)

time_major:如果是True,则输入需要是T×B×E,T代表时间序列的长度,B代表batch size,E代表词向量的维度。否则,为B×T×E。输出也是类似。

outputs:针对所有时间序列上的输出。

final_state:只是最后一个时间节点的状态。

Helper

常用的Helper:

  • TrainingHelper:适用于训练的helper。
  • InferenceHelper:适用于测试的helper。
  • GreedyEmbeddingHelper:适用于测试中采用Greedy策略sample的helper。

reuse

在训练和推断阶段,模型的decoder部分表现不同,而我们需要两者共享权值,所以需要变量管理.在生成上下文管理器时,将参数reuse设置为True.这样tf.get_variable函数将直接获取已经声明的变量,代码如下:

        with tf.variable_scope("decode"):
            training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
                                                                sequence_length=target_sequence_length,
                                                                time_major=False)

            training_decoder = tf.contrib.seq2seq.BasicDecoder(cell, training_helper, encoder_state, output_layer)
            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder, impute_finished=True,
                                                                              maximum_iterations=max_target_sequence_length)
        with tf.variable_scope("decode", reuse=True):
            start_tokens = tf.tile(tf.constant([target_letter_to_int['<BOS>']], dtype=tf.int32), [self.batch_size],name='start_token')
            predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings, start_tokens,
                                                                         target_letter_to_int['<EOS>'])

            predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,predicting_helper,encoder_state,output_layer)
            predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,impute_finished=True,
                                                                                maximum_iterations=max_target_sequence_length)

Loss Function

tf.contrib.seq2seq.sequence_loss可以直接计算序列的损失函数,重要参数: logits:尺寸[batch_size, sequence_length, num_decoder_symbols]

targets:尺寸[batch_size, sequence_length],不用做one_hot。

weights:[batch_size, sequence_length],即mask,滤去padding的loss计算,使loss计算更准确。

其中num_decoder_symbols指的是taget的词表大小,如果在阅读理解当中,source和target词表通常是一样,在机器翻译中则完全不同.还有就是要注意sequence_length是batch中的长度还是所有数据上的长度,一般来说可以取batch中的长度.