WGAN-GP学习笔记

738 阅读2分钟

GAN训练中存在的问题

判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练.

梯度消失

原始的GAN训练使用Binary Cross Entropy函数,二元交叉熵函数只能判断生成的图片是否是我们想要的,无法判断判别器的分布跟真实分布之间的距离。二元交叉熵在某些地方会导致梯度消失

模式崩塌

模式崩塌(Model Collapse)是指GAN训练时可能出现的一种现象。给定了一个z,当z发生变化的时候,对应的G(z)没有变化,那么在这个局部,GAN就发生了mode collapse,也就是不能产生连续变化的样本,从而样本多样性不足。模式崩塌的原因是判别器过拟合了。

过生成

过生成也是另一种形式的模式崩塌。过生成指的是生成的图片很奇怪,判别器无法判断真假

总结

  • 在GAN训练过程中难以收敛,这很大程度归咎于梯度消失
  • Binary Cross Entropy本身具有梯度消失的特性(梯度很小或趋于0),使得生成器和判别器没有办法根据梯度来更新参数
  • 因此需要更新度量标准,比如将BCE改为MSE

Wasserstein距离

请看我之前的文章,Wasserstein距离 WGAN将GAN的目标函数进行了替换

V(G,D)=maxD1Lipschitz{Exdata[Data]ExPG[D(x)]} V(G,D) = \max_{D\in 1-Lipschitz}\left\{ E_{x\sim data}[Data] - E_{x\sim P_G}[D(x)] \right\}

其中,1-Lipschitz是一个条件,只有满足了这个条件才符合Wasserstein距离,梯度的绝对值小于1

截屏2023-01-06 上午9.52.17.png

Gradient Penalty

WGAN-GP对weight-clipping(梯度裁剪)进行了优化,改用Gradient Penalty梯度惩罚的方法,提升了训练稳定性。

截屏2023-01-06 下午3.26.21.png 损失函数把log2给拿掉了,同时要满足1-Lipschitz条件。

截屏2023-01-06 下午3.28.52.png

GP的实现

  • 在高维空间进行梯度惩罚是十分困难的,因此采用退而求其次的方法,直接在生成分布和真实数据分布之间进行梯度惩罚。
  • 真实数据分布和生成器生成的数据分布我们是不知道的,因此作者在论文中使用一种插值的方法来估计生成的数据分布到真实分布的路径。图片插值看似没有意义,但在没有办法估计真实路径的情况下,这种插值至少可以鼓励在正确的地方进行梯度惩罚

def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * epsilon + fake * (1 - epsilon)

    # 计算 critic score
    mixed_scores = critic(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True
    )[0]

    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty