Environment:
TensorFlow: 1.15.0
Python: 3.6
Architecture:
Training: teacher forcing + attention
Testing: attention + beam search
The code for s2s.py is as follows:
import tensorflow as tf


class Seq2seq(object):

    def build_inputs(self, config):
        self.seq_inputs = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_inputs')
        self.seq_inputs_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_inputs_length')
        self.seq_targets = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_targets')
        self.seq_targets_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_targets_length')

    def __init__(self, config, w2i_target, useTeacherForcing=True, useBeamSearch=1):
        self.build_inputs(config)

        with tf.variable_scope("encoder"):
            encoder_embedding = tf.Variable(tf.random_uniform([config.source_vocab_size, config.embedding_dim]),
                                            dtype=tf.float32, name='encoder_embedding')
            encoder_inputs_embedded = tf.nn.embedding_lookup(encoder_embedding, self.seq_inputs)
            ((encoder_fw_outputs, encoder_bw_outputs),
             (encoder_fw_final_state, encoder_bw_final_state)) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=tf.nn.rnn_cell.GRUCell(config.hidden_dim),
                cell_bw=tf.nn.rnn_cell.GRUCell(config.hidden_dim),
                inputs=encoder_inputs_embedded,
                sequence_length=self.seq_inputs_length,
                dtype=tf.float32,
                time_major=False
            )
            encoder_state = tf.add(encoder_fw_final_state, encoder_bw_final_state)
            encoder_outputs = tf.add(encoder_fw_outputs, encoder_bw_outputs)
        with tf.variable_scope("decoder"):
            decoder_embedding = tf.Variable(tf.random_uniform([config.target_vocab_size, config.embedding_dim]),
                                            dtype=tf.float32, name='decoder_embedding')
            tokens_go = tf.ones([config.batch_size], dtype=tf.int32, name='tokens_GO') * w2i_target["_GO"]

            if useTeacherForcing:
                decoder_inputs = tf.concat([tf.reshape(tokens_go, [-1, 1]), self.seq_targets[:, :-1]], 1)
                helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs),
                                                           self.seq_targets_length)
            else:
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embedding, tokens_go, w2i_target["_EOS"])

            with tf.variable_scope("gru_cell"):
                decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
                # attention + beam search: use beamsearch=1 during training; at test time it can be >= 1
                if useBeamSearch > 1:
                    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=useBeamSearch)
                    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(self.seq_inputs_length,
                                                                          multiplier=useBeamSearch)
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim,
                                                                               memory=tiled_encoder_outputs,
                                                                               memory_sequence_length=tiled_sequence_length)
                    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
                    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(encoder_state,
                                                                              multiplier=useBeamSearch)
                    tiled_decoder_initial_state = decoder_cell.zero_state(
                        batch_size=config.batch_size * useBeamSearch, dtype=tf.float32)
                    tiled_decoder_initial_state = tiled_decoder_initial_state.clone(
                        cell_state=tiled_encoder_final_state)
                    decoder_initial_state = tiled_decoder_initial_state
                else:
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim,
                                                                               memory=encoder_outputs,
                                                                               memory_sequence_length=self.seq_inputs_length)
                    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
                    decoder_initial_state = decoder_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32)
                    decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
            # decoder
            if useBeamSearch > 1:
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell, decoder_embedding, tokens_go,
                                                               w2i_target["_EOS"], decoder_initial_state,
                                                               beam_width=useBeamSearch,
                                                               output_layer=tf.layers.Dense(config.target_vocab_size))
            else:
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state,
                                                          output_layer=tf.layers.Dense(config.target_vocab_size))

            # dynamic decode
            decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
                decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))
        if useBeamSearch > 1:
            self.out = decoder_outputs.predicted_ids[:, :, 0]
        else:
            decoder_logits = decoder_outputs.rnn_output
            self.out = tf.argmax(decoder_logits, 2)
            # loss and train_op are only built for the non-beam-search graph, where rnn_output exists
            sequence_mask = tf.sequence_mask(self.seq_targets_length, dtype=tf.float32)
            self.loss = tf.contrib.seq2seq.sequence_loss(logits=decoder_logits, targets=self.seq_targets,
                                                         weights=sequence_mask)
            self.train_op = tf.train.AdamOptimizer(learning_rate=config.learning_rate).minimize(self.loss)
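For orientation, here is a minimal usage sketch (mine, not part of the original files) of the three set-ups this constructor supports; config and w2i_target are assumed to be built as in train.py below, and each variant would live in its own fresh graph (as test.py does with tf.reset_default_graph()), since building two of them in the same graph would clash on the encoder/decoder variable scopes:

# Hypothetical sketch: each call should be made in its own (reset) default graph.
train_model = Seq2seq(config, w2i_target, useTeacherForcing=True, useBeamSearch=1)    # training: TrainingHelper + loss/train_op
greedy_model = Seq2seq(config, w2i_target, useTeacherForcing=False, useBeamSearch=1)  # testing: GreedyEmbeddingHelper
beam_model = Seq2seq(config, w2i_target, useTeacherForcing=False, useBeamSearch=3)    # testing: BeamSearchDecoder, beam width 3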
The code for train.py is as follows; you can run the Python file directly:
import tensorflow as tf
import random
import time
from os import path
from s2s import Seq2seq
tf_config = tf.ConfigProto(allow_soft_placement=True)
tf_config.gpu_options.allow_growth = True
class Config(object):
    embedding_dim = 100
    hidden_dim = 50
    batch_size = 128
    learning_rate = 0.005
    source_vocab_size = None
    target_vocab_size = None
def load_data():
    num2en = {"1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight",
              "9": "nine", "0": "zero"}
    docs_source = []
    docs_target = []
    for i in range(10000):
        doc_len = random.randint(1, 8)
        doc_source = []
        doc_target = []
        for j in range(doc_len):
            num = str(random.randint(0, 9))
            doc_source.append(num)
            doc_target.append(num2en[num])
        docs_source.append(doc_source)
        docs_target.append(doc_target)
    return docs_source, docs_target
def make_vocab(docs):
    w2i = {"_PAD": 0, "_GO": 1, "_EOS": 2}
    i2w = {0: "_PAD", 1: "_GO", 2: "_EOS"}
    for doc in docs:
        for w in doc:
            if w not in w2i:
                i2w[len(w2i)] = w
                w2i[w] = len(w2i)
    return w2i, i2w


def doc_to_seq(docs):
    w2i = {"_PAD": 0, "_GO": 1, "_EOS": 2}
    i2w = {0: "_PAD", 1: "_GO", 2: "_EOS"}
    seqs = []
    for doc in docs:
        seq = []
        for w in doc:
            if w not in w2i:
                i2w[len(w2i)] = w
                w2i[w] = len(w2i)
            seq.append(w2i[w])
        seqs.append(seq)
    return seqs, w2i, i2w
def get_batch(docs_source, w2i_source, docs_target, w2i_target, batch_size):
    ps = []
    while len(ps) < batch_size:
        ps.append(random.randint(0, len(docs_source) - 1))
    source_batch = []
    target_batch = []
    source_lens = [len(docs_source[p]) for p in ps]
    target_lens = [len(docs_target[p]) + 1 for p in ps]
    max_source_len = max(source_lens)
    max_target_len = max(target_lens)
    for p in ps:
        source_seq = [w2i_source[w] for w in docs_source[p]] + [w2i_source["_PAD"]] * (
                max_source_len - len(docs_source[p]))
        target_seq = [w2i_target[w] for w in docs_target[p]] + [w2i_target["_EOS"]] + [w2i_target["_PAD"]] * (
                max_target_len - 1 - len(docs_target[p]))
        source_batch.append(source_seq)
        target_batch.append(target_seq)
    return source_batch, source_lens, target_batch, target_lens
if __name__ == "__main__":
    print("(1) load data......")
    docs_source, docs_target = load_data()
    w2i_source, i2w_source = make_vocab(docs_source)
    w2i_target, i2w_target = make_vocab(docs_target)

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(w2i_source)
    config.target_vocab_size = len(w2i_target)
    model = Seq2seq(config=config, w2i_target=w2i_target)

    print("(3) run model......")
    batches = 3000
    print_every = 100
    model_dir = "./checkpoint"
    model_path = model_dir + "/model.ckpt"

    with tf.Session(config=tf_config) as sess:
        tf.summary.FileWriter('graph', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        # a checkpoint is stored as model.ckpt.index / .data-* / .meta, not a single model.ckpt file
        if path.isfile(model_path + ".index"):
            print("restore model")
            saver.restore(sess, model_path)

        losses = []
        total_loss = 0
        for batch in range(batches):
            source_batch, source_lens, target_batch, target_lens = get_batch(docs_source, w2i_source, docs_target,
                                                                             w2i_target, config.batch_size)
            feed_dict = {
                model.seq_inputs: source_batch,
                model.seq_inputs_length: source_lens,
                model.seq_targets: target_batch,
                model.seq_targets_length: target_lens
            }
            loss, _ = sess.run([model.loss, model.train_op], feed_dict)
            total_loss += loss

            if batch % print_every == 0 and batch > 0:
                print_loss = total_loss / print_every
                losses.append(print_loss)
                total_loss = 0
                print("-----------------------------")
                print("batch:", batch, "/", batches)
                print("time:", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
                print("loss:", print_loss)
                print("samples:\n")
                predict_batch = sess.run(model.out, feed_dict)
                for i in range(3):
                    print("in:", [i2w_source[num] for num in source_batch[i] if i2w_source[num] != "_PAD"])
                    print("out:", [i2w_target[num] for num in predict_batch[i] if i2w_target[num] != "_PAD"])
                    print("tar:", [i2w_target[num] for num in target_batch[i] if i2w_target[num] != "_PAD"])
                    print("")

        print(saver.save(sess, model_path))
        print(losses)
The training output is as follows:
(1) load data......
(2) build model......
(3) run model......
-----------------------------
batch: 100 / 3000
time: 2020-10-21 10:46:28
loss: 0.9439343801140785
samples:
in: ['5']
out: ['five', 'five', 'five', 'five', 'five', 'five', 'five', 'five', 'five']
tar: ['five', '_EOS']
in: ['4', '8', '8', '1', '9', '9', '3']
out: ['four', 'eight', 'eight', 'one', 'nine', 'nine', 'three', '_EOS', '_EOS']
tar: ['four', 'eight', 'eight', 'one', 'nine', 'nine', 'three', '_EOS']
in: ['2', '8', '5', '5', '7', '4']
out: ['two', 'eight', 'five', 'five', 'seven', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['two', 'eight', 'five', 'five', 'seven', 'four', '_EOS']
-----------------------------
batch: 200 / 3000
time: 2020-10-21 10:46:34
loss: 0.39233128368854525
samples:
in: ['8', '7', '8', '2', '7']
out: ['eight', 'seven', 'eight', 'two', 'seven', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['eight', 'seven', 'eight', 'two', 'seven', '_EOS']
in: ['0', '7']
out: ['zero', 'seven', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS', '_EOS']
tar: ['zero', 'seven', '_EOS']
in: ['6', '5', '2', '7', '8', '7']
out: ['six', 'five', 'two', 'seven', 'eight', 'seven', '_EOS', '_EOS', '_EOS']
tar: ['six', 'five', 'two', 'seven', 'eight', 'seven', '_EOS']
......
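To make the batch layout concrete, here is a hand-worked example (mine, not from the original post). Suppose get_batch samples the two pairs (['5'], ['five']) and (['4', '8', '8'], ['four', 'eight', 'eight']). Then max_source_len is 3 and max_target_len is 4 (target length plus one for _EOS), and the returned batch corresponds to

source_batch = [['5', '_PAD', '_PAD'], ['4', '8', '8']]                              source_lens = [1, 3]
target_batch = [['five', '_EOS', '_PAD', '_PAD'], ['four', 'eight', 'eight', '_EOS']]  target_lens = [2, 4]

shown here as tokens; the real return values are the corresponding integer ids from w2i_source / w2i_target.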
The code for test.py is as follows; you can run the Python file directly:
from train import *

tf.reset_default_graph()

tf_config = tf.ConfigProto(allow_soft_placement=True)
tf_config.gpu_options.allow_growth = True
model_path = "./checkpoint/model.ckpt"

if __name__ == "__main__":
    print("(1) reload data......")
    docs_source, docs_target = load_data()
    w2i_source, i2w_source = make_vocab(docs_source)
    w2i_target, i2w_target = make_vocab(docs_target)

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(w2i_source)
    config.target_vocab_size = len(w2i_target)
    model = Seq2seq(config=config, w2i_target=w2i_target, useTeacherForcing=False, useBeamSearch=1)

    print("(3) reload model......")
    print_every = 100
    max_target_len = 20

    with tf.Session(config=tf_config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        print("(4) test......")
        while True:
            text = input("Enter a digit string: ")
            source_batch = []
            source_lens = []
            target_batch = []
            target_batch_length = []
            text_list = []
            if text:
                text_list = [c for c in text]
                tmp = []
                for i, c in enumerate(text_list):
                    tmp.append(w2i_source[c])
                source_batch.append(tmp)
                target_batch.append(tmp + [w2i_target["_EOS"]] + [w2i_target["_PAD"]] * (max_target_len - 1 - len(tmp)))
                max_source_length = len(text_list)
                zero_padding = [[w2i_source["_PAD"]] * max_source_length]
                source_batch.extend(zero_padding * (Config.batch_size - 1))
                source_lens = [len(batch) for batch in source_batch]
                eos = [w2i_target["_EOS"]] + [w2i_target["_PAD"]] * (max_target_len - 1)
                target_batch.extend([eos] * (Config.batch_size - 1))
                target_batch_length = [len(batch) + 1 for batch in target_batch]

                feed_dict = {
                    model.seq_inputs: source_batch,
                    model.seq_inputs_length: source_lens,
                    model.seq_targets: [[0] * max_target_len] * len(source_batch),
                    model.seq_targets_length: [max_target_len] * len(source_batch)
                }

                print("samples:\n")
                predict_batch = sess.run(model.out, feed_dict)
                for i in range(1):
                    print("in:", [i2w_source[num] for num in source_batch[i] if i2w_source[num] != "_PAD"])
                    print("pre:", [i2w_target[num] for num in predict_batch[i] if i2w_target[num] != "_PAD"])
                    print("tar:", [i2w_target[num] for num in target_batch[i] if i2w_target[num] != "_PAD"])
                    print("")
The test output is as follows:
(1) reload data......
(2) build model......
(3) reload model......
(4) test......
Enter a digit string: 123123
samples:
in: ['1', '2', '3', '1', '2', '3']
pre: ['one', 'two', 'three', 'one', 'two', 'three', '_EOS']
tar: ['one', 'two', 'three', 'one', 'two', 'three', '_EOS']
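The architecture notes at the top say testing uses attention + beam search, while test.py above builds the model with useBeamSearch=1, i.e. greedy decoding. As a hedged sketch (mine, not from the original post), switching the test graph to beam search should only require changing the constructor call, since s2s.py already tiles the encoder outputs/state and reads predicted_ids when useBeamSearch > 1; the rest of test.py (feed_dict, saver.restore, model.out) is written so it can stay unchanged:

# Hypothetical variant of the model construction in test.py: beam search with beam width 3.
model = Seq2seq(config=config, w2i_target=w2i_target, useTeacherForcing=False, useBeamSearch=3)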