持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第29天,点击查看活动详情
一、 选题背景
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。WGAN算法是对GAN的改进,通过生成模型和判别模型的互相博弈学习优化,产生很好的输出,其在图像生成上有着广泛的应用。
本实验使用WGAN算法,实现生成动漫头像的目的。
二、 开发环境
在构建模型过程中,使用python3.8.8版本,Tensorflow是2.3.1版本。
导入的库有:
import multiprocessing
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
import numpy as np
from PIL import Image
import glob
三、 数据处理
1. 数据来源
本实验使用的是一组二次元头像数据集,数据集爬取自动漫图库网站konachan.net - Konachan.com Anime Wallpapers,从爬取的图片中截取人物头像作为数据训练。
2. 数据特点
实验使用保存在本地的动漫头像图片,共51223张,每张图片像素均为96x96。
3. 数据处理
实验过程中将数据随机打散,将图片的像素点的值变成[−1,1]之间。批量大小设置为512,预处理后的图像shape为64643。
四、 模型设计
1. 生成器
生成网络G由3个转置卷积层单元堆叠而成,实现特征图高宽的层层放大,特征图通道数的层层减少。首先对长度为100的隐藏向量z 进行reshape操作,并依序通过转置卷积层,放大高宽维度,减少通道数维度,最后得到高宽为64,通道数为3的彩色图片。每个转置卷积层中间插入BN层来提高训练稳定性。具体设计如下表:
| 输入 | 全连接层 | reshape | 转置卷积层1 | 转置卷积层2 | 转置卷积层3 | 输出 |
|---|---|---|---|---|---|---|
| 100维的随机噪声z | 神经元:33512 | 33512→(3,3,512) | 卷积核大小3*3,卷积核个数256,步长3 | 卷积核大小5*5,卷积核个数128,步长2 | 卷积核大小4*4,卷积核个数3,步长3 | 64643的图片 |
2. 判别器
判别网络D与普通的分类网络相同,接受大小为[ b , 64 , 64 , 3 ] 的图片张量,连续通过3个卷积层实现特征的层层提取 ,最后通过一个全连接层获得二分类任务的概率。
| 输入 | 卷积层1 | 卷积层2 | 卷积层3 | Flatten层 | 全连接层 | 输出 |
|---|---|---|---|---|---|---|
| 64643的图片 | 卷积核大小5*5,卷积核个数64,步长3 | 卷积核大小5*5,卷积核个数128,步长3 | 卷积核大小5*5,卷积核个数256,步长3 | 将多维数组一维化 | 神经元:1 | 1个输出节点,表示图片是真实图片的概率 |
五、 模型训练
1. 参数设置
2. 判别器误差函数
判别网络的训练目标是使得真实样本预测为真的概率接近于1,生成样本预测为真的概率接近于0。WGAN直接最大化真实样本的输出值,最小化生成样本的输出值。
将判别器的误差函数实现在d_loss_fn函数中,将所有真实样本标注为1,所有生成样本标注为0。d_loss_fn函数实现如下:
其中celoss_ones函数计算当前预测概率与标签1之间的交叉熵损失,代码如下:
celoss_zeros函数计算当前预测概率与标签0之间的交叉熵损失,代码如下:
gradient_penalty函数是WGAN-GP梯度惩罚函数,代码如下:
3. 生成器误差函数
由于真实样本与生成器无关,因此生成器误差函数只需要考虑最大化生成样本在判别器D的输出值。生成器的误差函数代码如下:
4. 训练及调优过程
训练时,首先创建生成网络和判别网络,并分别创建对应的优化器,优化器采用Adam优化器。交替训练生成器G和判别器D。代码如下:
六、 模型评测
每间隔100个Epoch,进行一次图片生成测试。随机采样隐向量,送入生成器生成图片,并保存为文件。
如下图所示,展示了WGAN-GP模型在训练过程中保存的生成图片样例,可以观察到,经过多次迭代后,大部分图片主体明确,色彩逼真,图片效果较为贴近数据集中真实的图片。同时也能发现仍有少量生成图片损坏,图片丰富度还不够。