In the previous article, we explained why recurrent neural networks are needed, described their structure, and used an RNN for text classification.
Today we put recurrent neural networks to work on text generation.
3.1 Text Generation: Data Processing
- Use the Shakespeare dataset
- 3.1.1 Reading the data
# imports used throughout this article
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras

# https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
input_filepath = "./shakespeare.txt"
text = open(input_filepath, 'r').read()
print(len(text))
print(text[0:100])
Output:
1115394
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You
- 3.1.2 Processing the data
Generating the vocabulary:
# 1. generate vocab
# 2. build mapping char->id
# 3. data -> id_data
# 4. abcd -> bcd<eos>
vocab = sorted(set(text))
print(len(vocab))
print(vocab)
Output:
65
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
- 3.1.3 Mapping characters to ids
char2idx = {char:idx for idx, char in enumerate(vocab)}
print(char2idx)
enumerate assigns an idx to every element of vocab; using that idx directly as the character's id is all we need:
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
As you can see, every character has been assigned an integer id: the newline character maps to 0, the space to 1, the exclamation mark to 2, and so on, each id matching the character's position in the sorted list. The vocab list itself can therefore serve as the id-to-character mapping.
- 3.1.4 Converting the vocab list to a numpy array
idx2char = np.array(vocab)
print(idx2char)
Output:
['\n' ' ' '!' '$' '&' "'" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E' 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W' 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o' 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']
- 3.1.5 Converting characters to ids
- Map every character in text to its id, producing a list of ids
text_as_int = np.array([char2idx[c] for c in text])
print(text_as_int[0:10])
print(text[0:10])
Output:
[18 47 56 57 58 1 15 47 58 47]
First Citi
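As a quick sanity check (a hypothetical snippet, not in the original code), mapping the ids back through idx2char reproduces the original characters:
# decode the first ten ids back to characters
print(''.join(idx2char[text_as_int[0:10]]))  # First Citi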
- 3.1.6 Building the sequence dataset
- from_tensor_slices turns the id array into a dataset
def split_input_target(id_text):
    """
    abcde -> abcd, bcde
    """
    return id_text[0:-1], id_text[1:]

char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
seq_length = 100
# this batch call groups individual characters into sequences
seq_dataset = char_dataset.batch(seq_length + 1,
                                 drop_remainder = True)

for ch_id in char_dataset.take(2):
    print(ch_id, idx2char[ch_id.numpy()])

for seq_id in seq_dataset.take(2):
    print(seq_id)
    print(repr(''.join(idx2char[seq_id.numpy()])))
Output:
tf.Tensor(18, shape=(), dtype=int64) F
tf.Tensor(47, shape=(), dtype=int64) i
tf.Tensor(
[18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43
1 54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43
39 56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49
6 1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10
0 37 53 59 1], shape=(101,), dtype=int64)
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
tf.Tensor(
[39 56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39 58 46 43 56 1
58 53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47 57 46 12 0 0
13 50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53 50 60 43 42 8
0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47 56 57 58 6 1
63 53 59 1 49], shape=(101,), dtype=int64)
'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
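A toy check (hypothetical, not in the original code) makes the one-character shift performed by split_input_target easy to see:
# the input is the sequence minus its last character,
# the target is the sequence minus its first character
inp, tgt = split_input_target(list('abcde'))
print(inp, tgt)  # ['a', 'b', 'c', 'd'] ['b', 'c', 'd', 'e']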
- 3.1.7 Calling split_input_target to get inputs and targets
seq_dataset = seq_dataset.map(split_input_target)

for item_input, item_output in seq_dataset.take(2):
    print(item_input.numpy())
    print(item_output.numpy())
Output:
[18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43
1 54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43
39 56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49
6 1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10
0 37 53 59]
[47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43 1
54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43 39
56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49 6
1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0
37 53 59 1]
[39 56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39 58 46 43 56 1
58 53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47 57 46 12 0 0
13 50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53 50 60 43 42 8
0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47 56 57 58 6 1
63 53 59 1]
[56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39 58 46 43 56 1 58
53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47 57 46 12 0 0 13
50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53 50 60 43 42 8 0
0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47 56 57 58 6 1 63
53 59 1 49]
- 3.1.8 Batching
batch_size = 64
buffer_size = 10000

# this batch call groups sequences into training batches
# drop_remainder=True: drop the last group if it is smaller than a full batch
seq_dataset = seq_dataset.shuffle(buffer_size).batch(
    batch_size, drop_remainder=True)
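A quick shape check (a hypothetical snippet, not in the original) confirms that each batch is now a pair of (64, 100) integer tensors:
for item_input, item_output in seq_dataset.take(1):
    print(item_input.shape, item_output.shape)  # (64, 100) (64, 100)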
3.2 Building the Model
- 3.2.1 Defining a function to build the model
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = keras.models.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim,
                               batch_input_shape = [batch_size, None]),
        keras.layers.SimpleRNN(units = rnn_units,
                               stateful = True,
                               recurrent_initializer = 'glorot_uniform',
                               # both input and output are sequences
                               return_sequences = True),
        keras.layers.Dense(vocab_size),
    ])
    return model

model = build_model(
    vocab_size = vocab_size,
    embedding_dim = embedding_dim,
    rnn_units = rnn_units,
    batch_size = batch_size)
model.summary()
Output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (64, None, 256) 16640
_________________________________________________________________
simple_rnn (SimpleRNN) (64, None, 1024) 1311744
_________________________________________________________________
dense (Dense) (64, None, 65) 66625
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
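These parameter counts can be verified by hand: the Embedding layer holds 65 × 256 = 16,640 weights; the SimpleRNN holds 256 × 1024 input weights plus 1024 × 1024 recurrent weights plus 1024 biases, i.e. 1,311,744; and the Dense layer holds 1024 × 65 + 65 = 66,625.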
- 3.2.2 Making predictions with the model
- To use the model, simply call it as a function
for input_example_batch, target_example_batch in seq_dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape)
Output:
(64, 100, 65)
64: batch_size
100: sequence length
65: vocab_size; each position is effectively a classification over the vocabulary
- 3.2.3 Random sampling
- Why do we need random sampling?
- Because in text generation the model produces a probability distribution at every position; we sample from that distribution to generate the text
- tf.random.categorical: draws random samples from a categorical distribution
- tf.squeeze: removes a dimension of size 1
# random sampling.
# two strategies: greedy, random.
sample_indices = tf.random.categorical(
    logits = example_batch_predictions[0], num_samples = 1)
print(sample_indices)

# (100, 1) -> (100,)
sample_indices = tf.squeeze(sample_indices, axis = -1)
print(sample_indices)
Output:
tf.Tensor(
[[52]
[25]
[ 3]
[21]
[43]
[45]
[ 3]
[51]
[56]
[26]
[25]
[27]
[26]
[58]
[25]
[34]
[42]
[48]
[27]
[19]
[24]
[45]
[12]
[44]
[39]
[59]
[26]
[ 5]
[48]
[48]
[23]
[12]
[28]
[58]
[30]
[33]
[15]
[22]
[24]
[ 1]
[19]
[ 7]
[58]
[32]
[ 1]
[62]
[ 5]
[26]
[23]
[51]
[58]
[30]
[16]
[61]
[31]
[17]
[16]
[13]
[26]
[11]
[31]
[63]
[ 0]
[ 0]
[45]
[29]
[54]
[47]
[ 8]
[ 2]
[35]
[ 2]
[39]
[32]
[54]
[41]
[47]
[33]
[56]
[27]
[27]
[28]
[13]
[55]
[62]
[51]
[25]
[64]
[53]
[15]
[51]
[ 3]
[21]
[62]
[37]
[17]
[64]
[ 9]
[30]
[17]], shape=(100, 1), dtype=int64)
tf.Tensor(
[52 25 3 21 43 45 3 51 56 26 25 27 26 58 25 34 42 48 27 19 24 45 12 44
39 59 26 5 48 48 23 12 28 58 30 33 15 22 24 1 19 7 58 32 1 62 5 26
23 51 58 30 16 61 31 17 16 13 26 11 31 63 0 0 45 29 54 47 8 2 35 2
39 32 54 41 47 33 56 27 27 28 13 55 62 51 25 64 53 15 51 3 21 62 37 17
64 9 30 17], shape=(100,), dtype=int64)
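The comment above mentions greedy as the alternative strategy. For comparison, a minimal greedy-decoding sketch (hypothetical, not in the original code) takes the argmax at each position instead of sampling:
# greedy: always pick the most likely character at each position
greedy_indices = tf.argmax(example_batch_predictions[0], axis=-1)
print(greedy_indices.shape)  # (100,)
Greedy decoding is deterministic and tends to get stuck repeating common patterns, which is why random sampling is preferred for generation.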
- 3.2.4 Printing input, output, and predictions
print("Input: ", repr("".join(idx2char[input_example_batch[0]])))
print()
print("Output: ", repr("".join(idx2char[target_example_batch[0]])))
print()
print("Predictions: ", repr("".join(idx2char[sample_indices])))
Output:
Input: 'ture\nTo mingle faith with him! Undone! undone!\nIf I might die within this hour, I have lived\nTo die '
Output: 'ure\nTo mingle faith with him! Undone! undone!\nIf I might die within this hour, I have lived\nTo die w'
Predictions: "nM$Ieg$mrNMONtMVdjOGLg?fauN'jjK?PtRUCJL G-tT x'NKmtRDwSEDAN;Sy\n\ngQpi.!W!aTpciUrOOPAqxmMzoCm$IxYEz3RE"
The predictions are gibberish because the model has not been trained yet.
- 3.2.5 Defining the loss and compiling the model
def loss(labels, logits):
    return keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)

model.compile(optimizer = 'adam', loss = loss)
example_loss = loss(target_example_batch, example_batch_predictions)
print(example_loss.shape)
print(example_loss.numpy().mean())
Output:
(64, 100)
4.1839275
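As a sanity check on this value: an untrained model predicts roughly uniformly over the 65 characters, so the expected cross-entropy is about ln(65) ≈ 4.17, which matches the 4.18 measured above.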
- 3.2.6 Training the model
output_dir = "./text_generation_checkpoints"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath = checkpoint_prefix,
    save_weights_only = True)

epochs = 100
history = model.fit(seq_dataset, epochs = epochs,
                    callbacks = [checkpoint_callback])
- 3.2.7 Checking the latest saved checkpoint
tf.train.latest_checkpoint(output_dir)
3.3 Generating Text by Sampling
- 3.3.1 Loading the model
model2 = build_model(vocab_size,
                     embedding_dim,
                     rnn_units,
                     batch_size = 1)
model2.load_weights(tf.train.latest_checkpoint(output_dir))
model2.build(tf.TensorShape([1, None]))

# text generation flow:
# start ch sequence A,
# A -> model -> b
# A.append(b) -> B
# B(Ab) -> model -> c
# B.append(c) -> C
# C(Abc) -> model -> ...
model2.summary()
Output:
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_2 (Embedding) (1, None, 256) 16640
_________________________________________________________________
simple_rnn_2 (SimpleRNN) (1, None, 1024) 1311744
_________________________________________________________________
dense_2 (Dense) (1, None, 65) 66625
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
- 3.3.2 Text generation
- tf.expand_dims: adds a dimension
def generate_text(model, start_string, num_generate = 1000):
    input_eval = [char2idx[ch] for ch in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []
    model.reset_states()
    for _ in range(num_generate):
        # 1. model inference -> predictions
        # 2. sample -> ch -> text_generated.
        # 3. update input_eval
        # predictions : [batch_size, input_eval_len, vocab_size]
        predictions = model(input_eval)
        # predictions : [input_eval_len, vocab_size]
        predictions = tf.squeeze(predictions, 0)
        # predicted_ids: [input_eval_len, 1]
        # a b c -> b c d
        predicted_id = tf.random.categorical(
            predictions, num_samples = 1)[-1, 0].numpy()
        text_generated.append(idx2char[predicted_id])
        # s, x -> rnn -> s', y
        input_eval = tf.expand_dims([predicted_id], 0)
    return start_string + ''.join(text_generated)

new_text = generate_text(model2, "All: ")
print(new_text)
Output:
All: gentle: And it is, to water of will'd-pullens in the enemy?
On if thou be gone to the cropk.
Clown:
I will.
MARIANA:
Give me thyself was shumber;
To invuinate, as then I recking
To revout then unknown reberish weeds,
As you have true lips as believe us, good sir, witnession great unnot
God'I are murderer, for I will unselt the wallou preventest my bones
Of whom I know, or hatisage of thee?
GRESO:
Ar, but your stuffic ound thou exerving, braved or now? what overme of wrath; to their inferring thousand's tender, that I should private in Romes but of Ingles
That you 'darished recompensy in perfucuted-word:
To neity, give me home:
The true so weary Grevant:
Might there better they, when I lighted mornce, tell me, hither towards here but dead, this Aumerave be married.
ROMEO:
Upon the body, look, thy wisdom, I must beseech your ears, above:
In such arging from Oxford sho!
GLOUCESTER:
Nurse:
Good marvengrannith that borness, by set must arrays, far you would sit by: thou but he should
As you can see, the result is much better than before training, but the overall quality is still poor: many of the generated "words" are not real words.
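A common refinement, not part of the code above, is a temperature parameter: dividing the logits by a temperature below 1 sharpens the distribution and makes sampling more conservative, while a temperature above 1 makes it more random. A minimal sketch (the function name and default value are illustrative):
def generate_text_with_temperature(model, start_string,
                                   num_generate = 1000, temperature = 0.5):
    # identical to generate_text above, except the logits are divided
    # by the temperature before sampling
    input_eval = tf.expand_dims([char2idx[ch] for ch in start_string], 0)
    text_generated = []
    model.reset_states()
    for _ in range(num_generate):
        predictions = tf.squeeze(model(input_eval), 0) / temperature
        predicted_id = tf.random.categorical(
            predictions, num_samples = 1)[-1, 0].numpy()
        text_generated.append(idx2char[predicted_id])
        input_eval = tf.expand_dims([predicted_id], 0)
    return start_string + ''.join(text_generated)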
That's fine for now; in the next article we'll look at a more powerful recurrent network: the LSTM.