Bi-LSTM 的 tensorflow-1.14 实现

1,494 阅读4分钟

原理自己上网找,这里只用简单的示例来展示原理,输入一句话,用 BI-LSTM 预测下一个词。

import tensorflow as tf
import numpy as np

sentences = 'Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud exercitation'
# UNK 单词作为占位符号
words = ['UNK'] + list(set([word for word in sentences.split()]))
# 生成单词和索引的映射关系
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}

V = len(words)  # 字典大小
step = len(sentence.split())  # 这里将 sentence 的长度作为了输入句子的长度
hidden = 50   # 隐层大小
dim = 50  # 词向量维度
def make_batch(sentence):
    input_batch, target_batch = [],[]
    words = sentence.split()
    for i,word in enumerate(words[:-1]):
        input = [word2idx[word] for word in words[:i+1]]  # 将不同长度的句子转换成索引列表,当作输入的句子
        input += [0] * (step - len(input))  # 将长度不足 step 的句子用 UNK 补齐
        target = word2idx[words[i+1]]  # 用 input 句子的下一个单词索引作为标签
        input_batch.append(input)
        target_batch.append(np.eye(V)[target])  # 将 target 转换为 one-hot 编码,之后的 softmax_cross_entropy_with_logits_v2 会用到
    return input_batch, target_batch
    
tf.reset_default_graph()

# 初始化 词向量
embedding = tf.get_variable(name="embedding", shape=[V, dim], initializer=tf.random_normal_initializer)
X = tf.placeholder(tf.int32, [None, step])
# 词嵌入
XX = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, V])

# 前向 LSTM
lstm_fw_cell = tf.nn.rnn_cell.LSTMCell(hidden)
# 反向 LSTM
lstm_bw_cell = tf.nn.rnn_cell.LSTMCell(hidden)

# 隐层计算
# outputs : (fw=[batch_size, step, hidden], bw=[batch_size, step, hidden])
# states : (fw=(c=[batch_size, hidden], h=[batch_size, hidden]), bw=(c=[batch_size, hidden], h=[batch_size, hidden]))
outputs, states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, XX, dtype=tf.float32)

# 将最后输出的正反方向的向量进行拼接
outputs = tf.concat([outputs[0], outputs[1]], 2) # [batch_size, step, 2*hidden]
outputs = tf.transpose(outputs, [1,0,2])[-1]    # [batch_size, 2*hidden]

# 连接隐层和分类器的权重参数和偏置参数
W = tf.Variable(tf.random_normal([2*hidden, V]))
b = tf.Variable(tf.random_normal([V]))

# 结果概率
logits = tf.matmul(outputs, W) + b  # [batch_size, V]

# 计算损失并进行优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)

# 预测
prediction = tf.argmax(logits, 1)


init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 生产输入和标签
input_batch, target_batch = make_batch(sentence)
# 训练模型
for epoch in range(3000):
    _, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
    if (epoch+1)%500 == 0:
        print("epoch=",'%04d'%(epoch+1)," loss=", '%04f'%(loss))
        
# 预测
predict = sess.run([prediction], feed_dict={X:input_batch})
for i,idxs in enumerate(input_batch):
    print(" ".join([idx2word[idx] for idx in idxs if idx!=0]) ,' 预测单词:',idx2word[predict[0][i]], ' 真实单词:', idx2word[np.argmax(target_batch[i])])       
    

结果打印:

epoch= 0500  loss= 0.001467
epoch= 1000  loss= 0.000409
epoch= 1500  loss= 0.000199
epoch= 2000  loss= 0.000116
epoch= 2500  loss= 0.000074
epoch= 3000  loss= 0.000050
Lorem  预测单词: ipsum  真实单词: ipsum
Lorem ipsum  预测单词: dolor  真实单词: dolor
Lorem ipsum dolor  预测单词: sit  真实单词: sit
Lorem ipsum dolor sit  预测单词: amet  真实单词: amet
Lorem ipsum dolor sit amet  预测单词: consectetur  真实单词: consectetur
Lorem ipsum dolor sit amet consectetur  预测单词: adipisicing  真实单词: adipisicing
Lorem ipsum dolor sit amet consectetur adipisicing  预测单词: elit  真实单词: elit
Lorem ipsum dolor sit amet consectetur adipisicing elit  预测单词: sed  真实单词: sed
Lorem ipsum dolor sit amet consectetur adipisicing elit sed  预测单词: do  真实单词: do
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do  预测单词: eiusmod  真实单词: eiusmod
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod  预测单词: tempor  真实单词: tempor
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor  预测单词: incididunt  真实单词: incididunt
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt  预测单词: ut  真实单词: ut
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut  预测单词: labore  真实单词: labore
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore  预测单词: et  真实单词: et
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et  预测单词: dolore  真实单词: dolore
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore  预测单词: magna  真实单词: magna
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna  预测单词: aliqua  真实单词: aliqua
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua  预测单词: Ut  真实单词: Ut
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut  预测单词: enim  真实单词: enim
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim  预测单词: ad  真实单词: ad
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad  预测单词: minim  真实单词: minim
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim  预测单词: veniam  真实单词: veniam
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam  预测单词: quis  真实单词: quis
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis  预测单词: nostrud  真实单词: nostrud
Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud  预测单词: exercitation  真实单词: exercitation