前言
使用 tensorflow 2.10 实现 AutoEncoder 模型,还原生成 MNIST 图片。
数据
- 从 TensorFlow 的 Keras 库中加载 MNIST 数据集。MNIST 数据集包含了手写数字的灰度图像。
load_data()函数会返回两个元组,分别包含了训练数据和测试数据。因为本文只关心图像本身,不关心标签,所以我们用下划线_来表示标签。 - 将加载的图像数据进行标准化,也就是将像素值缩放到
0到1之间。MNIST 图像的像素值范围通常是0到255之间的整数,所以我们将每个像素值除以255.0来将其缩放到 0 到 1 之间的浮点数,有助于神经网络更好地学习图像特征,也有利于网络的快速收敛。 - MNIST 图像数据中每个图像的尺寸为
28x28像素,但在训练神经网络时,我们需要将图像数据展平成一维向量,因此将训练和测试数据的形状从(B, 28, 28)变为(B, 784)。
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
模型
我们定义了一个简单的 AutoEncoder 模型,其中包括一个编码器和一个解码器,并设置了模型的优化器和损失函数。具体如下:
- 将模型
AutoEncoder的输入大小设置为(784,),对应于上面转化后的每个图像展平后的像素数。 - 然后定义了编码器部分,它是一个包含一个全连接层。该层有
16个神经元,使用ReLU作为激活函数。编码器主要是为了对图像进行压缩和编码。 - 接着定义了解码器部分,它也是一个全连接层,它的输出维度与编码器的输入数据维度相同
(784维)。解码器使用Sigmoid激活函数,以确保输出值在 [0, 1] 范围内,因为我们输入的图像数据值在这个范围中。解码器的作用就是将编码器的特征尽量还原回原图像。 - 最后编译
AutoEncoder模型,使用Adam作为优化器,使用均方误差(Mean Squared Error,MSE)作为损失函数。
input_dim = 784
input_layer = tf.keras.layers.Input(shape=(input_dim,))
encoder = tf.keras.layers.Dense(16, activation='relu')(input_layer)
decoder = tf.keras.layers.Dense(input_dim, activation='sigmoid')(encoder)
autoencoder = tf.keras.models.Model(inputs=input_layer, outputs=decoder)
autoencoder.compile(optimizer='adam', loss='mse')
训练
- 对
AutoEncoder模型进行10个次循环训练,每次循环只进行一个epoch训练,这里是为将每一个epoch的训练结果进行保存,便于后期观察训练效果。并使用批量大小为256的样本数据进行训练。训练数据和验证数据都是x_train,这是因为AutoEncoder的目标是学习如何将输入数据编码并尽量还原回去。 - 在每个训练迭代之后,使用训练好的
AutoEncoder模型对测试数据x_test进行预测,得到了decoded_img,即重构后的图像。 - 在每次循环中取部分个测试图片原图和预测图,绘制成对比的效果图并保存到本地。因为有
10次循环,所以最终会生成10张图片。
for i in range(10):
autoencoder.fit(x_train, x_train, epochs=1, batch_size=256, validation_data=(x_test, x_test))
decoded_img = autoencoder.predict(x_test)
n = 10
plt.figure(figsize=(18, 4))
for j in range(n):
ax = plt.subplot(2, n, j + 1)
plt.imshow(x_test[j].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, j + 1 + n)
plt.imshow(decoded_img[j].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.tight_layout()
plt.savefig(f'example{i}.png')
训练日志打印:
235/235 [==============================] - 1s 2ms/step - loss: 0.0812 - val_loss: 0.0501
313/313 [==============================] - 0s 594us/step
235/235 [==============================] - 0s 2ms/step - loss: 0.0435 - val_loss: 0.0381
313/313 [==============================] - 0s 521us/step
...
235/235 [==============================] - 1s 2ms/step - loss: 0.0241 - val_loss: 0.0235
313/313 [==============================] - 0s 547us/step
235/235 [==============================] - 0s 2ms/step - loss: 0.0236 - val_loss: 0.0230
313/313 [==============================] - 0s 563us/step
效果
这里主要是将上面生成的 10 张图片合成一个 gif ,以 1 帧/秒的速率播放,可以看出随着模型的训练,模型生成的图片的还原程度逐步提升,如果增加模型架构的复杂度,生成效果还可以进一步提升。
image_files = [f'example{i}.png' for i in range(10)]
images = []
fig = plt.figure()
def update(frame):
img = plt.imread(image_files[frame])
plt.imshow(img)
plt.axis('off')
images.append([plt.imshow(img, animated=True)])
ani = animation.FuncAnimation(fig, update, frames=len(image_files), repeat=False)
output_gif = "output.gif"
ani.save(output_gif, writer='pillow', fps=1)
效果展示: