GAN 模型

123 阅读5分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第20天,点击查看活动详情

GAN

生成模型 vs 判别模型

机器学习过程中是利用有限样本尽可能准确的估计出后验概率 P(yx)P(y|x),通俗讲从属性 xx 去预测 yy,有两种基本策略来进行解决。

对于生成模型,首先对联合概率分布 P(x,y)P(x,y) 进行建模,再由此获得 P(yx)P(y|x)。涉及公式:P(yx)=P(x,y)P(x)P(y \mid x)=\frac{P(x, y)}{P(x)},经典模型包括:朴素贝叶斯,LDA等

对于判别模型,直接对 P(yx)P(y|x) 进行建模,获取值。经典模型包括:线性模型等。

举个例子:假设现在有一头🐏,怎样对它进行识别?

判别式:要确定一个羊是山羊还是绵羊,用判别模型的方法是从历史数据中学习到模型,然后通过提取这只羊的特征来预测出这只羊是山羊的概率,是绵羊的概率。

生成式:利用生成模型是根据山羊的特征首先学习出一个山羊的模型,然后根据绵羊的特征学习出一个绵羊的模型,然后从这只羊中提取特征,放到山羊模型中看概率是多少,在放到绵羊模型中看概率是多少,哪个大就是哪个。

总结讲:处理分类问题为例,生成模型处于宏观的角度,划分出每个类别的范围。判别模型处于微观角度,找到各个类别之间差异,通过差异去寻找。

img

训练过程

GAN,即生成对抗网络,主要包含两个模块:生成器(Generative Model)和判别器(Discriminative Model)。

生成器的主要任务是学习真实图片集,从而使得自己生成的图片更接近于真实图片,以“骗过”判别器。

判别器的主要任务是找出出生成器生成的图片,区分其与真实图片的不同,进行真假判别。

举个例子说明一下过程:

假设给你一张图片,我希望识别出一张钞票。

第一次:

image.png

生成模型输出左边一张照片,判别器进行判断,判断左边照片是 FAKE 。

第二次:

image.png

经过第一次之后,生成模型输出左边照片,判别器进行判断,判断左边照片是 FAKE 。

经过不断的迭代过程,优化判别器和生成器,最后:

image.png

生成器输出的图片,判别器无法判定为 FAKE,整个模型训练成功。

在整个迭代过程中,生成器不断努力让生成的图片越来越像真的,而判别器不断努力识别出图片的真假。这类似生成器与判别器之间的博弈,随着反复迭代,最终二者达到了平衡:【生成器生成的图片非常接近于真实图片,而判别器已经很难识别出真假图片的不同了。其表现是对于真假图片,判别器的概率输出都接近 0.5。】

结构

GAN 结构图如下:

image.png

在具体训练过程中,生成器 (Generative Model)与判别器(Discriminative Model)交替训练:

  • 首先,固定生成器,基于随机向量 zz 模拟出 G(z)G(z) 作为负样本,并从真实数据中采样得到正样本 xx,然后将正负样本输入给判别器,进行二分类预测,最后利用其二分类交叉熵损失更新判别器参数;
  • 之后固定判别器,目的是优化生成器,对于生成器,为了尽可能欺骗判别器,即尽量让判别器将生成的负样本判为正样本,一般考虑以最大化生成样本的判别概率为目标来优化。

Two neural networks contest with each other in a game (in the form of a zero-sum game, where one agent's gain is another agent's loss).

优化目标函数:

minGmaxDV(G,D)=minGmaxDExpdata [logD(x)]+Ezpz[log(1D(G(z))]\min _{G} \max _{D} V(G, D)=\min _{G} \max _{D} \mathrm{E}_{x \sim p_{\text {data }}}[\log D(x)]+\mathrm{E}_{z \sim p_{z}}[\log (1-D(G(z))]

说明公式重要部分;

  • xx :真实数据样本。
  • zz :输入G网络的噪声,随机分布采集的样本
  • GG:生成器 Generative 。
  • DD:判别器 Discriminative 。
  • G(z)G(z):输入噪声,生成一个样本。
  • D(x)D(x):判断真实图片是否真实的概率,这个值越接近 1 越好。
  • D(G(z))D(G(z)):判断G生成的图片的是否真实的概率。

DD 的能力越强,D(x)D(x) 应该越大,D(G(x))D(G(x)) 应该越小。这时 V(D,G)V(D,G) 会变大,因此求最大值 maxmax,也就是说我们希望判别器强大,对生成数据与真实数据进行分类识别。

GG 应该希望自己生成的图片越接近真实越好GG 希望 D(G(z))D(G(z)) 尽可能得大,能够使 DD 进行误判置 1 ,这时 V(D,G)V(D, G) 会变小,因此求最小值 minmin

梯度训练过程描述:

image.png

当我们训练 DD 时候,求 V(G,D)V(G,D) 最大值,选择上升(ascending)梯度去更新模型。

当我们训练 GG 时候,求 V(G,D)V(G,D) 最小值,选择下降(descending)梯度去更新模型。

变体

CycleGAN

之前有着一个很火的功能:将图片风格变成漫画风格。CycleGAN 就可以初步实现这个神奇的功能:图片的风格迁移,具体例子如下图:

img

CycleGAN的特点在于能够在源域和目标域之间,无须建立训练数据间一对一的映射,就可实现这种迁移。

原理结构图如下:

img

GG:从 X 到 Y 的生成器

FF:从 Y 到 X 的生成器

D(x)D(x):鉴别 X 的判别器

D(y)D(y):鉴别 Y 的判别器

当输入 xx 经过 GG 生成器,生成一行 y^\hat y 图片,送入 D(y)D(y) 判别器进行判断,再送入到 FF 生成器,生成一个 x^\hat x 图片,将 xxx^\hat x 进行比较求 loss,另一个过程是一样的,可以看图理解。

最重要一点:cycle-consistency loss,主要目的:用数据集中其他的图来检验生成器,防止 GGFF 过拟合,比如想把一个小狗照片转化成梵高风格,如果没有 cycle-consistency loss,生成器可能会生成一张梵高真实画作来骗过D(x)D(x),而无视输入的小狗。

相关公式展示:

G:XYL(G,DY,X,Y)=Eypdata (y)[logDY(y)]+Expdata (x)[log(1DY(G(x)))]F:YXL(F,DX,Y,X)=Expdata (x)[logDX(x)]+Eypdata (y)[log(1DX(F(y)))Lcyc(G,F)=Expdata (x)[F(G(x))x1]+Eypdata (y)[G(F(y))y1L(G,F,Dx,Dy)=L(G,DY,X,Y)+L(F,DX,Y,X)+λLcyc (G,F)\begin{array}{l} G: X \rightarrow Y \\ \mathscr{L}\left(G, D_{Y}, X, Y\right)=\mathbb{E}_{y \sim p_{\text {data }}(y)}\left[\log D_{Y}(y)\right]+\mathbb{E}_{x \sim p_{\text {data }}(x)}\left[\log \left(1-D_{Y}(G(x))\right)\right] \\ F: Y \rightarrow X \\ \mathscr{L}\left(F, D_{X}, Y, X\right)=\mathbb{E}_{x \sim p_{\text {data }}(x)}\left[\log D_{X}(x)\right]+\mathbb{E}_{y \sim p_{\text {data }}(y)}\left[\log \left(1-D_{X}(F(y))\right)\right. \\ \mathscr{L}_{c y c}(G, F)=\mathbb{E}_{x \sim p_{\text {data }}(x)}\left[\|F(G(x))-x\|_{1}\right]+\mathbb{E}_{y \sim p_{\text {data }}(y)}\left[\|G(F(y))-y\|_{1}\right. \\ \mathscr{L}\left(G, F, D_{x}, D_{y}\right)=\mathscr{L}\left(G, D_{Y}, X, Y\right)+\mathscr{L}\left(F, D_{X}, Y, X\right)+\lambda \mathscr{L}_{\text {cyc }}(G, F) \end{array}

注意点:λ\lambda 会控制 cycle-consistency loss,这意味着当 λ\lambda 越大,希望模型 xxx^\hat x 更加相识。

参考