送书 | AI插画师:如何用基于PyTorch的生成对抗网络生成动漫头像?

3,296 阅读19分钟
本文由 「AI前线」原创,原文链接:送书 | AI插画师:如何用基于PyTorch的生成对抗网络生成动漫头像?
作者|陈云
编辑|Natalie

AI 前线导读:”2016 年是属于 TensorFlow 的一年,凭借谷歌的大力推广,TensorFlow 占据了各大媒体的头条。2017 年年初,PyTorch 的横空出世吸引了研究人员极大的关注,PyTorch 简洁优雅的设计、统一易用的接口、追风逐电的速度和变化无方的灵活性给人留下深刻的印象。

本文节选自《深度学习框架 PyTorch 入门与实践》第 7 章,为读者讲解当前最火爆的生成对抗网络(GAN),带领读者从零开始实现一个动漫头像生成器,能够利用 GAN 生成风格多变的动漫头像。注意啦,文末有送书福利!”


生成对抗网络(Generative Adversarial Net,GAN)是近年来深度学习中一个十分热门的方向,卷积网络之父、深度学习元老级人物 LeCun Yan 就曾说过“GAN is the most interesting idea in the last 10 years in machine learning”。尤其是近两年,GAN 的论文呈现井喷的趋势,GitHub 上有人收集了各种各样的 GAN 变种、应用、研究论文等,其中有名称的多达数百篇。作者还统计了 GAN 论文发表数目随时间变化的趋势,如图 7-1 所示,足见 GAN 的火爆程度。

图 7-1 GAN 的论文数目逐月累加图


GAN 的原理简介

GAN 的开山之作是被称为“GAN 之父”的 Ian Goodfellow 发表于 2014 年的经典论文 Generative Adversarial Networks ,在这篇论文中他提出了生成对抗网络,并设计了第一个 GAN 实验——手写数字生成。

GAN 的产生来自于一个灵机一动的想法:

“What I cannot create, I do not understand.”(那些我所不能创造的,我也没有真正地理解它。)
—Richard Feynman

类似地,如果深度学习不能创造图片,那么它也没有真正地理解图片。当时深度学习已经开始在各类计算机视觉领域中攻城略地,在几乎所有任务中都取得了突破。但是人们一直对神经网络的黑盒模型表示质疑,于是越来越多的人从可视化的角度探索卷积网络所学习的特征和特征间的组合,而 GAN 则从生成学习角度展示了神经网络的强大能力。GAN 解决了非监督学习中的著名问题:给定一批样本,训练一个系统能够生成类似的新样本

生成对抗网络的网络结构如图 7-2 所示,主要包含以下两个子网络。

  • 生成器(generator):输入一个随机噪声,生成一张图片。
  • 判别器(discriminator):判断输入的图片是真图片还是假图片。

图 7-2 生成对抗网络结构图

训练判别器时,需要利用生成器生成的假图片和来自真实世界的真图片;训练生成器时,只用噪声生成假图片。判别器用来评估生成的假图片的质量,促使生成器相应地调整参数。

生成器的目标是尽可能地生成以假乱真的图片,让判别器以为这是真的图片;判别器的目标是将生成器生成的图片和真实世界的图片区分开。可以看出这二者的目标相反,在训练过程中互相对抗,这也是它被称为生成对抗网络的原因。

上面的描述可能有点抽象,让我们用收藏齐白石作品(齐白石作品如图 7-3 所示)的书画收藏家和假画贩子的例子来说明。假画贩子相当于是生成器,他们希望能够模仿大师真迹伪造出以假乱真的假画,骗过收藏家,从而卖出高价;书画收藏家则希望将赝品和真迹区分开,让真迹流传于世,销毁赝品。这里假画贩子和收藏家所交易的画,主要是齐白石画的虾。齐白石画虾可以说是画坛一绝,历来为世人所追捧。

图 7-3 齐白石画虾图真迹

在这个例子中,一开始假画贩子和书画收藏家都是新手,他们对真迹和赝品的概念都很模糊。假画贩子仿造出来的假画几乎都是随机涂鸦,而书画收藏家的鉴定能力很差,有不少赝品被他当成真迹,也有许多真迹被当成赝品。

首先,书画收藏家收集了一大堆市面上的赝品和齐白石大师的真迹,仔细研究对比,初步学习了画中虾的结构,明白画中的生物形状弯曲,并且有一对类似钳子的“螯足”,对于不符合这个条件的假画全部过滤掉。当收藏家用这个标准到市场上进行鉴定时,假画基本无法骗过收藏家,假画贩子损失惨重。但是假画贩子自己仿造的赝品中,还是有一些蒙骗过关,这些蒙骗过关的赝品中都有弯曲的形状,并且有一对类似钳子的“螯足”。于是假画贩子开始修改仿造的手法,在仿造的作品中加入弯曲的形状和一对类似钳子的“螯足”。除了这些特点,其他地方例如颜色、线条都是随机画的。假画贩子制造出的第一版赝品如图 7-4 所示。

图 7-4 假画贩子制造的第一版赝品

当假画贩子把这些画拿到市面上去卖时,很容易就骗过了收藏家,因为画中有一只弯曲的生物,生物前面有一对类似钳子的东西,符合收藏家认定的真迹的标准,所以收藏家就把它当成真迹买回来。随着时间的推移,收藏家买回越来越多的假画,损失惨重,于是他又闭门研究赝品和真迹之间的区别,经过反复比较对比,他发现齐白石画虾的真迹中除了有弯曲的形状,虾的触须蔓长,通身作半透明状,并且画的虾的细节十分丰富,虾的每一节之间均呈白色状。

收藏家学成之后,重新出山,而假画贩子的仿造技法没有提升,所制造出来的赝品被收藏家轻松识破。于是假画贩子也开始尝试不同的画虾手法,大多都是徒劳无功,不过在众多尝试之中,还是有一些赝品骗过了收藏家的眼睛。假画贩子发现这些仿制的赝品触须蔓长,通身作半透明状,并且画的虾的细节十分丰富,如图 7-5 所示。于是假画贩子开始大量仿造这种画,并拿到市面上销售,许多都成功地骗过了收藏家。

图 7-5 假画贩子制造的第二版赝品

收藏家再度损失惨重,被迫关门研究齐白石的真迹和赝品之间的区别,学习齐白石真迹的特点,提升自己的鉴定能力。就这样,通过收藏家和假画贩子之间的博弈,收藏家从零开始慢慢提升了自己对真迹和赝品的鉴别能力,而假画贩子也不断地提高自己仿造齐白石真迹的水平。收藏家利用假画贩子提供的赝品,作为和真迹的对比,对齐白石画虾真迹有了更好的鉴赏能力;而假画贩子也不断尝试,提升仿造水平,提升仿造假画的质量,即使最后制造出来的仍属于赝品,但是和真迹相比也很接近了。收藏家和假画贩子二者之间互相博弈对抗,同时又不断促使着对方学习进步,达到共同提升的目的。

在这个例子中,假画贩子相当于一个生成器,收藏家相当于一个判别器。一开始生成器和判别器的水平都很差,因为二者都是随机初始化的。训练过程分为两步交替进行,第一步是训练判别器(只修改判别器的参数,固定生成器),目标是把真迹和赝品区分开;第二步是训练生成器(只修改生成器的参数,固定判别器),为的是生成的假画能够被判别器判别为真迹(被收藏家认为是真迹)。这两步交替进行,进而分类器和判别器都达到了一个很高的水平。训练到最后,生成器生成的虾的图片(如图 7-6 所示)和齐白石的真迹几乎没有差别。

图 7-6 生成器生成的虾

下面我们来思考网络结构的设计。判别器的目标是判断输入的图片是真迹还是赝品,所以可以看成是一个二分类网络,参考第 6 章中 Dog vs. Cat 的实验,我们可以设计一个简单的卷积网络。生成器的目标是从噪声中生成一张彩色图片,这里我们采用广泛使用的 DCGAN(Deep Convolutional Generative Adversarial Networks)结构,即采用全卷积网络,其结构如图 7-7 所示。网络的输入是一个 100 维的噪声,输出是一个 3×64×64 的图片。这里的输入可以看成是一个 100×1×1 的图片,通过上卷积慢慢增大为 4×4、8×8、16×16、32×32 和 64×64。上卷积,或称转置卷积,是一种特殊的卷积操作,类似于卷积操作的逆运算。当卷积的 stride 为 2 时,输出相比输入会下采样到一半的尺寸;而当上卷积的 stride 为 2 时,输出会上采样到输入的两倍尺寸。这种上采样的做法可以理解为图片的信息保存于 100 个向量之中,神经网络根据这 100 个向量描述的信息,前几步的上采样先勾勒出轮廓、色调等基础信息,后几步上采样慢慢完善细节。网络越深,细节越详细。

图 7-7 DCGAN 中生成器网络结构图

在 DCGAN 中,判别器的结构和生成器对称:生成器中采用上采样的卷积,判别器中就采用下采样的卷积,生成器是根据噪声输出一张 64×64×3 的图片,而判别器则是根据输入的 64×64×3 的图片输出图片属于正负样本的分数(概率)。


用 GAN 生成动漫头像

本节将用 GAN 实现一个生成动漫人物头像的例子。在日本的技术博客网站上 有个博主(估计是一位二次元的爱好者),利用 DCGAN 从 20 万张动漫头像中学习,最终能够利用程序自动生成动漫头像,生成的图片效果如图 7-8 所示。源程序是利用 Chainer 框架实现的,本节我们尝试利用 PyTorch 实现。

图 7-8 DCGAN 生成的动漫头像

原始的图片是从网站中爬取的,并利用 OpenCV 从中截取头像,处理起来比较麻烦。这里我们使用知乎用户何之源爬取并经过处理的 5 万张图片。可以从本书配套程序的 README.MD 的百度网盘链接下载所有的图片压缩包,并解压缩到指定的文件夹中。需要注意的是,这里图片的分辨率是 3×96×96,而不是论文中的 3×64×64,因此需要相应地调整网络结构,使生成图像的尺寸为 96。

我们首先来看本实验的代码结构。

接着来看 model.py 中是如何定义生成器的。

可以看出生成器的搭建相对比较简单,直接使用 nn.Sequential 将上卷积、激活、池化等操作拼接起来即可,这里需要注意上卷积 ConvTransposed2d 的使用。当 kernel size 为 4、stride 为 2、padding 为 1 时,根据公式 H_out=(H_in-1)*stride-2*padding+kernel_size,输出尺寸刚好变成输入的两倍。最后一层采用 kernel size 为 5、stride 为 3、padding 为 1,是为了将 32×32 上采样到 96×96,这是本例中图片的尺寸,与论文中 64×64 的尺寸不一样。最后一层用 Tanh 将输出图片的像素归一化至 -1~1,如果希望归一化至 0~1,则需使用 Sigmoid。接着我们来看判别器的网络结构。

可以看出判别器和生成器的网络结构几乎是对称的,从卷积核大小到 padding、stride 等设置,几乎一模一样。例如生成器的最后一个卷积层的尺度是(5,3,1),判别器的第一个卷积层的尺度也是(5,3,1)。另外,这里需要注意的是生成器的激活函数用的是 ReLU,而判别器使用的是 LeakyReLU,二者并无本质区别,这里的选择更多是经验总结。每一个样本经过判别器后,输出一个 0~1 的数,表示这个样本是真图片的概率。在开始写训练函数前,先来看看模型的配置参数。

这些只是模型的默认参数,还可以利用 Fire 等工具通过命令行传入,覆盖默认值。另外,我们也可以直接使用 opt.attr,还可以利用 IDE/IPython 提供的自动补全功能,十分方便。这里的超参数设置大多是照搬 DCGAN 论文的默认值,作者经过大量实验,发现这些参数能够更快地训练出一个不错的模型。

当我们下载完数据之后,需要将所有图片放在一个文件夹,然后将该文件夹移动至 data 目录下(请确保 data 下没有其他的文件夹)。这种处理方式是为了能够直接使用 torchvision 自带的 ImageFolder 读取图片,而不必自己写 Dataset。数据读取与加载的代码如下:

可见,用 ImageFolder 配合 DataLoader 加载图片十分方便。

在进行训练之前,我们还需要定义几个变量:模型、优化器、噪声等。


在加载预训练模型时,最好指定 map_location。因为如果程序之前在 GPU 上运行,那么模型就会被存成 torch.cuda.Tensor,这样加载时会默认将数据加载至显存。如果运行该程序的计算机中没有 GPU,加载就会报错,故通过指定 map_location 将 Tensor 默认加载入内存中,待有需要时再移至显存中。

下面开始训练网络,训练步骤如下。

(1)训练判别器。

  • 固定生成器
  • 对于真图片,判别器的输出概率值尽可能接近 1
  • 对于生成器生成的假图片,判别器尽可能输出 0

(2)训练生成器。

  • 固定判别器
  • 生成器生成图片,尽可能让判别器输出 1

(3)返回第一步,循环交替训练。

这里需要注意以下几点。

  • 训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
  • 在训练判别器时,需要对生成器生成的图片用 detach 操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
  • 在训练分类器时,需要反向传播两次,一次是希望把真图片判为 1,一次是希望把假图片判为 0。也可以将这两者的数据放到一个 batch 中,进行一次前向传播和一次反向传播即可。但是人们发现,在一个 batch 中只包含真图片或只包含假图片的做法最好。
  • 对于假图片,在训练判别器时,我们希望它输出为 0;而在训练生成器时,我们希望它输出为 1。因此可以看到一对看似矛盾的代码:error_d_fake = criterion(fake_output, fake_labels) 和 error_g = criterion(fake_output, true_labels)。其实这也很好理解,判别器希望能够把假图片判别为 fake_label,而生成器则希望能把它判别为 true_label,判别器和生成器互相对抗提升。

接下来就是一些可视化的代码。每次可视化使用的噪声都是固定的 fix_noises,因为这样便于我们比较对于相同的输入,生成器生成的图片是如何一步步提升的。另外,由于我们对输入的图片进行了归一化处理(-1~1),在可视化时则需要将它还原成原来的 scale(0~1) 。

除此之外,还提供了一个函数,能加载预训练好的模型,并利用噪声随机生成图片。

完整的代码请参考本书的附带样例代码 chapter7/AnimeGAN。参照 README.MD 中的指南配置环境,并准备好数据,而后用如下命令即可开始训练:

如果使用 visdom 的话,此时打开 http://[your ip]:8097 就能看到生成的图像。

训练完成后,我们可以利用生成网络随机生成动漫头像,输入命令如下:


实验结果分析

实验结果如图 7-9 所示,分别是训练 1 个、10 个、20 个、30 个、40 个、200 个 epoch 之后神经网络生成的动漫头像。需要注意的是,每次生成器输入的噪声都是一样的,所以我们可以对比在相同的输入下,生成图片的质量是如何慢慢改善的。

刚开始生成的图像比较模糊(1 个 epoch),但是可以看出图像已经有面部轮廓。

继续训练 10 个 epoch 之后,生成的图多了很多细节信息,包括头发、颜色等,但是总体还是很模糊。

训练 20 个 epoch 之后,细节继续完善,包括头发的纹理、眼睛的细节等,但还是有不少涂抹的痕迹。

训练到第 40 个 epoch 时,已经能看出明显的面部轮廓和细节,但还是有涂抹现象,并且有些细节不够合理,例如眼睛一大一小,面部的轮廓扭曲严重。

当训练到 200 个 epoch 之后,图片的细节已经十分完善,线条更流畅,轮廓更清晰,虽然还有一些不合理之处,但是已经有不少图片能够以假乱真了。

图 7-9 GAN 生成的动漫头像

类似的生成动漫头像的项目还有“用 DRGAN 生成高清的动漫头像”,效果如图 7-10 所示。但遗憾的是,由于论文中使用的数据涉及版权问题,未能公开。这篇论文的主要改进包括使用了更高质量的图片数据和更深、更复杂的模型。

图 7-10 用 DRGAN 生成的动漫头像

本章讲解的样例程序还可以应用到不同的生成图片场景中,只要将训练图片改成其他类型的图片即可,例如 LSUN 客房图片集、MNIST 手写数据集或 CIFAR10 数据集等。事实上,上述模型还有很大的改进空间。在这里,我们使用的全卷积网络只有四层,模型比较浅,而在 ResNet 的论文发表之后,也有不少研究者尝试在 GAN 的网络结构中引入 Residual Block 结构,并取得了不错的视觉效果。感兴趣的读者可以尝试将示例代码中的单层卷积修改为 Residual Block,相信可以取得不错的效果。

近年来,GAN 的一个重大突破在于理论研究。论文 Towards Principled Methods for Training Generative Adversarial Networks 从理论的角度分析了 GAN 为何难以训练,作者随后在另一篇论文 Wasserstein GAN 中针对性地提出了一个更好的解决方案。但是 Wasserstein GAN 这篇论文在部分技术细节上的实现过于随意,所以随后又有人有针对性地提出 Improved Training of Wasserstein GANs,更好地训练 WGAN。后面两篇论文分别用 PyTorch 和 TensorFlow 实现,代码可以从 GitHub 上搜索到。笔者当初也尝试用 100 行左右的代码实现了 Wasserstein GAN,感兴趣的读者可以去了解 。

随着 GAN 研究的逐渐成熟,人们也尝试把 GAN 用于工业实际问题之中,而在众多相关论文中,最令人印象深刻的就是 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ,论文中提出了一种新的 GAN 结构称为 CycleGAN。CycleGAN 利用 GAN 实现风格迁移、黑白图像彩色化,以及马和斑马相互转化等,效果十分出众。论文的作者用 PyTorch 实现了所有代码,并开源在 GitHub 上,感兴趣的读者可以自行查阅。

本章主要介绍 GAN 的基本原理,并带领读者利用 GAN 生成动漫头像。GAN 有许多变种,GitHub 上有许多利用 PyTorch 实现的各种 GAN,感兴趣的读者可以自行查阅。

作者介绍

陈云,Python 程序员、Linux 爱好者和 PyTorch 源码贡献者。主要研究方向包括计算机视觉和机器学习。“2017 知乎看山杯机器学习挑战赛”一等奖,“2017 天池医疗 AI 大赛”第八名。热衷于推广 PyTorch,并有丰富的使用经验,活跃于 PyTorch 论坛和知乎相关板块。

福利!福利!我们将给 AI 前线的粉丝送出《深度学习框架 PyTorch 入门与实践》纸质书籍 10 本!在本文下方留言给出你想要这本书的理由,我们会邀请你加入赠书群,本次获奖名单由抽奖小程序随机抽取,2 月 6 日(周二)上午 10 点开奖,获奖者每人获得一本。另附京东购买地址,戳「阅读原文」!

更多干货内容,可关注AI前线,ID:ai-front,后台回复「AI」、「TF」、「大数据」可获得《AI前线》系列PDF迷你书和技能图谱。