tensorflow 实现 Seq2seq

320 阅读3分钟

环境:

tensorflow:1.15.0
python:3.6

架构:

训练:Teacher Forcing + Attention 
测试:Attention + Beamsearch(test.py 中 useBeamSearch=1 时为贪心解码,大于 1 时启用 Beam Search)

s2s.py 代码如下

import tensorflow as tf


class Seq2seq(object):
    """GRU encoder-decoder with Bahdanau attention, built as a TF1 graph.

    Constructing an instance builds the whole graph and exposes:
      * ``seq_inputs``, ``seq_inputs_length``, ``seq_targets``,
        ``seq_targets_length`` -- int32 placeholders (see ``build_inputs``);
      * ``out`` -- predicted target token ids;
      * ``loss`` and ``train_op`` -- only built when ``useBeamSearch <= 1``.

    NOTE(review): variable names come from the scopes and the creation order
    below, and train.py saves / test.py restores the same checkpoint, so
    restructuring this graph may break checkpoint compatibility -- confirm
    before refactoring.
    """

    def build_inputs(self, config):
        """Create the four int32 input placeholders.

        The batch dimension is fixed to ``config.batch_size``, so every feed
        must contain exactly that many rows; the time dimension is ``None``.
        ``*_length`` placeholders hold one length per batch row.
        """
        self.seq_inputs = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_inputs')
        self.seq_inputs_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_inputs_length')
        self.seq_targets = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_targets')
        self.seq_targets_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_targets_length')

    def __init__(self, config, w2i_target, useTeacherForcing=True, useBeamSearch=1):
        """Build encoder, attention decoder, outputs and (optionally) training ops.

        Args:
            config: object providing batch_size, source_vocab_size,
                target_vocab_size, embedding_dim, hidden_dim and learning_rate.
            w2i_target: target vocabulary mapping; must contain "_GO" and "_EOS".
            useTeacherForcing: if True the decoder is fed the gold previous
                target token (TrainingHelper); otherwise it feeds back its own
                greedy prediction (GreedyEmbeddingHelper). Not used when
                beam search is enabled.
            useBeamSearch: beam width. Values > 1 use BeamSearchDecoder and do
                not build ``loss``/``train_op``; 1 uses a BasicDecoder.
        """

        self.build_inputs(config)

        with tf.variable_scope("encoder"):

            encoder_embedding = tf.Variable(tf.random_uniform([config.source_vocab_size, config.embedding_dim]),
                                            dtype=tf.float32, name='encoder_embedding')
            encoder_inputs_embedded = tf.nn.embedding_lookup(encoder_embedding, self.seq_inputs)

            # Bidirectional GRU encoder over the padded batch; seq_inputs_length
            # tells the RNN where each row really ends.
            ((encoder_fw_outputs, encoder_bw_outputs),
             (encoder_fw_final_state, encoder_bw_final_state)) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=tf.nn.rnn_cell.GRUCell(config.hidden_dim),
                cell_bw=tf.nn.rnn_cell.GRUCell(config.hidden_dim),
                inputs=encoder_inputs_embedded,
                sequence_length=self.seq_inputs_length,
                dtype=tf.float32,
                time_major=False
            )
            # Forward and backward results are summed (not concatenated), so they
            # stay hidden_dim wide and the final state can seed the hidden_dim
            # decoder GRU below.
            encoder_state = tf.add(encoder_fw_final_state, encoder_bw_final_state)
            encoder_outputs = tf.add(encoder_fw_outputs, encoder_bw_outputs)

        with tf.variable_scope("decoder"):

            decoder_embedding = tf.Variable(tf.random_uniform([config.target_vocab_size, config.embedding_dim]),
                                            dtype=tf.float32, name='decoder_embedding')

            # One _GO id per batch row: the first decoder input.
            tokens_go = tf.ones([config.batch_size], dtype=tf.int32, name='tokens_GO') * w2i_target["_GO"]

            if useTeacherForcing:
                # Shift targets right: prepend _GO and drop the last target token,
                # so step t is fed the gold token t-1.
                decoder_inputs = tf.concat([tf.reshape(tokens_go, [-1, 1]), self.seq_targets[:, :-1]], 1)
                helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs),
                                                           self.seq_targets_length)
            else:
                # Feed back the model's own greedy predictions, with _EOS as end token.
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embedding, tokens_go, w2i_target["_EOS"])

            with tf.variable_scope("gru_cell"):
                decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)

                # attention + beam search: the beam width is 1 during training;
                # at test time it may be greater than or equal to 1.
                if useBeamSearch > 1:
                    # Beam search decodes batch_size * beam_width hypotheses, so the
                    # attention memory, its lengths and the encoder state are tiled
                    # by the beam width to line up with them.
                    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=useBeamSearch)
                    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(self.seq_inputs_length,
                                                                          multiplier=useBeamSearch)
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim,
                                                                               memory=tiled_encoder_outputs,
                                                                               memory_sequence_length=tiled_sequence_length)
                    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
                    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(encoder_state,
                                                                              multiplier=useBeamSearch)
                    # Start from the wrapper's zero state, but replace the inner GRU
                    # state with the (tiled) encoder final state.
                    tiled_decoder_initial_state = decoder_cell.zero_state(
                        batch_size=config.batch_size * useBeamSearch, dtype=tf.float32)
                    tiled_decoder_initial_state = tiled_decoder_initial_state.clone(
                        cell_state=tiled_encoder_final_state)
                    decoder_initial_state = tiled_decoder_initial_state
                else:
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim,
                                                                               memory=encoder_outputs,
                                                                               memory_sequence_length=self.seq_inputs_length)
                    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
                    # Zero attention state, inner GRU state seeded with the encoder state.
                    decoder_initial_state = decoder_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32)
                    decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

            # decoder (in the beam-search branch `helper` is built but not used)
            if useBeamSearch > 1:
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell, decoder_embedding, tokens_go,
                                                               w2i_target["_EOS"], decoder_initial_state,
                                                               beam_width=useBeamSearch,
                                                               output_layer=tf.layers.Dense(config.target_vocab_size))
            else:
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state,
                                                          output_layer=tf.layers.Dense(config.target_vocab_size))

            # dynamic decode; the step limit comes from seq_targets_length, so that
            # placeholder must be fed even at inference time.
            decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                                                                       maximum_iterations=tf.reduce_max(
                                                                                                           self.seq_targets_length))

        if useBeamSearch > 1:
            # Keep beam index 0 only (presumably the best-scoring beam -- TODO confirm).
            self.out = decoder_outputs.predicted_ids[:, :, 0]
        else:
            decoder_logits = decoder_outputs.rnn_output
            # Predicted ids: argmax over the vocabulary axis of the logits.
            self.out = tf.argmax(decoder_logits, 2)

            # Weight positions beyond each row's seq_targets_length by 0 in the loss.
            # NOTE(review): sequence_loss needs logits and targets of equal time
            # length; that holds under teacher forcing, but with the greedy helper
            # decoding may stop early, so only evaluate loss/train_op in training mode.
            sequence_mask = tf.sequence_mask(self.seq_targets_length, dtype=tf.float32)
            self.loss = tf.contrib.seq2seq.sequence_loss(logits=decoder_logits, targets=self.seq_targets,
                                                         weights=sequence_mask)

            self.train_op = tf.train.AdamOptimizer(learning_rate=config.learning_rate).minimize(self.loss)

train.py 代码如下,直接运行 python 文件即可

import tensorflow as tf
import random
import time
from os import path
from s2s import Seq2seq

# Session options shared by the training run (and re-exported to test.py via
# `from train import *`): soft device placement and on-demand GPU memory growth.
tf_config = tf.ConfigProto()
tf_config.allow_soft_placement = True
tf_config.gpu_options.allow_growth = True

class Config(object):
    """Hyper-parameters for building and training the Seq2seq model.

    All values are class attributes, so they can be read from an instance or
    directly from the class. The two vocabulary sizes start as ``None`` and are
    meant to be assigned on an instance once the vocabularies are built.
    """

    # Vocabulary sizes -- unknown until the vocabularies exist.
    source_vocab_size = None
    target_vocab_size = None

    # Model dimensions.
    embedding_dim = 100
    hidden_dim = 50

    # Optimisation settings.
    batch_size = 128
    learning_rate = 0.005


def load_data(num_docs=10000, max_doc_len=8, seed=2020):
    """Generate a toy parallel corpus: digit sequences and their English words.

    Each source document is 1..max_doc_len random digit tokens (e.g. ["1", "5"])
    and the aligned target document spells them out (e.g. ["one", "five"]).

    The corpus is drawn from a private ``random.Random(seed)``. train.py and
    test.py each call this function in their own process, and make_vocab()
    numbers words by first appearance. With an unseeded corpus the two runs
    therefore built different vocabularies, and the restored model's ids were
    decoded with the wrong mapping. A fixed seed keeps them identical. The
    private RNG also leaves the global ``random`` state (used for batching)
    untouched. Pass ``seed=None`` for a fresh, non-reproducible corpus.

    Args:
        num_docs: number of (source, target) document pairs to generate.
        max_doc_len: maximum number of tokens per document (minimum is 1).
        seed: seed for the private RNG; ``None`` seeds from system entropy.

    Returns:
        ``(docs_source, docs_target)``: two parallel lists of token lists.
    """
    num2en = {"1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight",
              "9": "nine", "0": "zero"}
    rng = random.Random(seed)
    docs_source = []
    docs_target = []
    for _ in range(num_docs):
        doc_len = rng.randint(1, max_doc_len)
        doc_source = [str(rng.randint(0, 9)) for _ in range(doc_len)]
        docs_source.append(doc_source)
        docs_target.append([num2en[num] for num in doc_source])

    return docs_source, docs_target


def make_vocab(docs):
    """Build word<->index lookup tables from tokenised documents.

    Indices 0, 1 and 2 are reserved for "_PAD", "_GO" and "_EOS". Every other
    token gets the next free index the first time it is seen, scanning the
    documents in order.

    Returns:
        ``(w2i, i2w)``: word -> index and index -> word dictionaries.
    """
    reserved = ("_PAD", "_GO", "_EOS")
    word_to_id = {token: idx for idx, token in enumerate(reserved)}
    for token in (w for doc in docs for w in doc):
        # len() is evaluated before insertion, so a new token gets the next id.
        word_to_id.setdefault(token, len(word_to_id))
    id_to_word = {idx: token for token, idx in word_to_id.items()}
    return word_to_id, id_to_word


def doc_to_seq(docs):
    """Convert tokenised documents to id sequences while building the vocabulary.

    Uses the same numbering as make_vocab: 0/1/2 are "_PAD"/"_GO"/"_EOS", and
    other tokens are numbered in order of first appearance.

    Returns:
        ``(seqs, w2i, i2w)``: one id list per document, plus both lookup tables.
    """
    vocab = {"_PAD": 0, "_GO": 1, "_EOS": 2}
    id_seqs = []
    for document in docs:
        # setdefault returns the existing id, or registers the next free one.
        id_seqs.append([vocab.setdefault(token, len(vocab)) for token in document])
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return id_seqs, vocab, reverse_vocab


def get_batch(docs_source, w2i_source, docs_target, w2i_target, batch_size):
    """Sample a random, padded training batch of parallel documents.

    ``batch_size`` document indices are drawn independently with
    ``random.randint`` (so duplicates are possible). Source rows are padded with
    "_PAD" to the longest sampled source. Target rows get an "_EOS" appended,
    then are padded with "_PAD" to the longest target.

    Returns:
        ``(source_batch, source_lens, target_batch, target_lens)`` where
        target lengths include the appended "_EOS".
    """
    last_index = len(docs_source) - 1
    picks = [random.randint(0, last_index) for _ in range(batch_size)]

    chosen_src = [docs_source[i] for i in picks]
    chosen_tgt = [docs_target[i] for i in picks]

    source_lens = [len(doc) for doc in chosen_src]
    target_lens = [len(doc) + 1 for doc in chosen_tgt]  # +1 for "_EOS"

    src_width = max(source_lens)
    tgt_width = max(target_lens)

    source_batch = [
        [w2i_source[w] for w in doc] + [w2i_source["_PAD"]] * (src_width - len(doc))
        for doc in chosen_src
    ]
    target_batch = [
        [w2i_target[w] for w in doc] + [w2i_target["_EOS"]]
        + [w2i_target["_PAD"]] * (tgt_width - 1 - len(doc))
        for doc in chosen_tgt
    ]

    return source_batch, source_lens, target_batch, target_lens


if __name__ == "__main__":
    import os  # only the training entry point needs os.makedirs

    print("(1) load data......")
    docs_source, docs_target = load_data()
    w2i_source, i2w_source = make_vocab(docs_source)
    w2i_target, i2w_target = make_vocab(docs_target)

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(w2i_source)
    config.target_vocab_size = len(w2i_target)
    model = Seq2seq(config=config, w2i_target=w2i_target)

    print("(3) run model......")
    batches = 3000
    print_every = 100

    model_dir = "./checkpoint"
    model_path = path.join(model_dir, "model.ckpt")
    # tf.train.Saver.save() does not create a missing parent directory; without
    # this the first save (at batch 100) fails on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)

    with tf.Session(config=tf_config) as sess:
        # Write the graph once for TensorBoard, then release the event file.
        graph_writer = tf.summary.FileWriter('graph', sess.graph)
        graph_writer.close()
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        # TF checkpoints are stored as model.ckpt.index / model.ckpt.data-* /
        # model.ckpt.meta plus a "checkpoint" state file. No file literally named
        # "model.ckpt" exists, so path.isfile() never allowed a resume. Ask TF.
        latest_ckpt = tf.train.latest_checkpoint(model_dir)
        if latest_ckpt:
            print("restore model")
            saver.restore(sess, latest_ckpt)

        losses = []
        total_loss = 0
        window_batches = 0  # how many batch losses are summed in total_loss
        for batch in range(batches):
            source_batch, source_lens, target_batch, target_lens = get_batch(
                docs_source, w2i_source, docs_target, w2i_target, config.batch_size)

            feed_dict = {
                model.seq_inputs: source_batch,
                model.seq_inputs_length: source_lens,
                model.seq_targets: target_batch,
                model.seq_targets_length: target_lens
            }

            loss, _ = sess.run([model.loss, model.train_op], feed_dict)
            total_loss += loss
            window_batches += 1

            if batch % print_every == 0 and batch > 0:
                # Average over the batches actually summed: the first window spans
                # batches 0..print_every, i.e. print_every + 1 batches.
                print_loss = total_loss / window_batches
                losses.append(print_loss)
                total_loss = 0
                window_batches = 0
                print("-----------------------------")
                print("batch:", batch, "/", batches)
                print("time:", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
                print("loss:", print_loss)

                print("samples:\n")
                predict_batch = sess.run(model.out, feed_dict)
                for i in range(3):
                    print("in:", [i2w_source[num] for num in source_batch[i] if i2w_source[num] != "_PAD"])
                    print("out:", [i2w_target[num] for num in predict_batch[i] if i2w_target[num] != "_PAD"])
                    print("tar:", [i2w_target[num] for num in target_batch[i] if i2w_target[num] != "_PAD"])
                    print("")
                print(saver.save(sess, model_path))
        print(losses)

       

运行结果如下:

(1) load data......
(2) build model......
(3) run model......
-----------------------------
batch: 100 / 3000
time: 2020-10-21 10:46:28
loss: 0.9439343801140785
samples:

in: ['5']
out: ['five', 'five', 'five', 'five', 'five', 'five', 'five', 'five', 'five']
tar: ['five', '_EOS']

in: ['4', '8', '8', '1', '9', '9', '3']
out: ['four', 'eight', 'eight', 'one', 'nine', 'nine', 'three', '_EOS', '_EOS']
tar: ['four', 'eight', 'eight', 'one', 'nine', 'nine', 'three', '_EOS']

in: ['2', '8', '5', '5', '7', '4']
out: ['two', 'eight', 'five', 'five', 'seven', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['two', 'eight', 'five', 'five', 'seven', 'four', '_EOS']

-----------------------------
batch: 200 / 3000
time: 2020-10-21 10:46:34
loss: 0.39233128368854525
samples:

in: ['8', '7', '8', '2', '7']
out: ['eight', 'seven', 'eight', 'two', 'seven', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['eight', 'seven', 'eight', 'two', 'seven', '_EOS']

in: ['0', '7']
out: ['zero', 'seven', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['zero', 'seven', '_EOS']

in: ['6', '5', '2', '7', '8', '7']
out: ['six', 'five', 'two', 'seven', 'eight', 'seven', '_EOS', '_EOS', '_EOS']
tar: ['six', 'five', 'two', 'seven', 'eight', 'seven', '_EOS']

......

test.py 代码如下,需先运行 train.py 生成 ./checkpoint 下的模型,然后直接运行该 python 文件即可

from train import *

# Begin with an empty default graph, so the model built below is the only one.
tf.reset_default_graph()

# Same session options as training: soft device placement, on-demand GPU memory.
tf_config = tf.ConfigProto()
tf_config.allow_soft_placement = True
tf_config.gpu_options.allow_growth = True

# Must match the checkpoint path train.py saves to.
model_path = "./checkpoint/model.ckpt"

if __name__ == "__main__":
    print("(1) reload data......")
    docs_source, docs_target = load_data()
    w2i_source, i2w_source = make_vocab(docs_source)
    w2i_target, i2w_target = make_vocab(docs_target)
    # docs_source and docs_target are token-aligned, so they give the reference
    # target word for each source token. Only the "tar:" line uses it.
    src2tgt = {s: t
               for doc_s, doc_t in zip(docs_source, docs_target)
               for s, t in zip(doc_s, doc_t)}

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(w2i_source)
    config.target_vocab_size = len(w2i_target)
    model = Seq2seq(config=config, w2i_target=w2i_target, useTeacherForcing=False, useBeamSearch=1)

    print("(3) reload model......")
    max_target_len = 20

    pad_source = w2i_source["_PAD"]
    pad_target = w2i_target["_PAD"]
    eos_target = w2i_target["_EOS"]

    with tf.Session(config=tf_config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        print("(4) test......")
        while True:
            try:
                text = input("输入数字字符串:")
            except EOFError:  # stdin closed: leave the interactive loop cleanly
                break

            text_list = list(text)
            if not text_list:
                # An empty query would feed a zero-width batch; just ask again.
                continue
            unknown = [c for c in text_list if c not in src2tgt]
            if unknown:
                # Previously a KeyError crashed the whole session.
                print("unsupported characters:", unknown)
                continue

            # Decode at least max_target_len steps, and always enough for every
            # input token plus the trailing _EOS (long inputs used to be truncated).
            decode_len = max(max_target_len, len(text_list) + 1)

            # The placeholders have a fixed batch dimension (config.batch_size):
            # row 0 is the real query, the remaining rows are all-_PAD fillers.
            query = [w2i_source[c] for c in text_list]
            source_batch = [query] + [[pad_source] * len(query) for _ in range(config.batch_size - 1)]
            source_lens = [len(row) for row in source_batch]

            # Reference translation as *target*-vocabulary ids. The old code reused
            # the source ids and only matched by coincidence of vocabulary order.
            reference = [w2i_target[src2tgt[c]] for c in text_list] + [eos_target]

            feed_dict = {
                model.seq_inputs: source_batch,
                model.seq_inputs_length: source_lens,
                # With the greedy helper the target ids are not fed to the decoder;
                # seq_targets_length bounds decoding via maximum_iterations.
                model.seq_targets: [[pad_target] * decode_len] * len(source_batch),
                model.seq_targets_length: [decode_len] * len(source_batch)
            }

            print("samples:\n")
            predict_batch = sess.run(model.out, feed_dict)
            print("in:", [i2w_source[num] for num in source_batch[0] if i2w_source[num] != "_PAD"])
            print("pre:", [i2w_target[num] for num in predict_batch[0] if i2w_target[num] != "_PAD"])
            print("tar:", [i2w_target[num] for num in reference if i2w_target[num] != "_PAD"])
            print("")

运行结果如下:

(1) reload data......
(2) build model......
(3) reload model......
(4) test......
输入数字字符串:123123
samples:

in: ['1', '2', '3', '1', '2', '3']
pre: ['one', 'two', 'three', 'one', 'two', 'three', '_EOS']
tar: ['one', 'two', 'three', 'one', 'two', 'three', '_EOS']