GAN系列——BEGAN

53 阅读2分钟

简介

BEGAN 提出了一种新的均衡训练方法,该方法与从 Wasserstein 距离导出的损失相匹配,用于训练基于自动编码器的 GAN。该方法在训练期间平衡生成器和鉴别器。此外,它提供了一种新的近似收敛度量、快速稳定训练和高视觉质量的方法。推导出一种控制图像多样性和视觉质量之间权衡的方法。

典型的 GAN 试图直接匹配数据分布,但我们的方法旨在使用从 Wasserstein 距离导出的损失来匹配自动编码器损失分布。这是使用一个典型的GAN目标完成的,并添加了一个平衡项来平衡鉴别器和生成器。与典型的 GAN 技术相比,我们的方法具有更简单的训练过程和使用更简单的神经网络结构。

判别器

判别器的结构是自编码器结构,输入是真实图片或生成图片,然后通过基于 CNN 的 Encoder,将图像编码为 Embedding。接着将 Embedding 通过基于 转置 CNN 的 Decoder 得到重建后的图片。

图片

生成器

生成器的结构和判别器的 Decoder 一致,输入噪声向量,输出生成的图片。

损失函数为:

图片

图片

平衡项

在实践中,保持生成器和判别器损失之间的平衡至关重要。我们认为它们在以下情况下处于平衡状态:

图片

图片

最终的 Loss

BEGAN4.png

BEGAN7.png

收敛性测量

BEGAN5.png

该度量可用于确定网络何时达到其最终状态或模型是否已崩溃。

Code

# 超参数
gamma = 0.75
lambda_k = 0.001
k = 0.0

# 生成噪声
z = Variable(Tensor(np.random.normal(01, (imgs.shape[0], opt.latent_dim))))

# 训练 G
gen_imgs = generator(z)
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs)) # 用了 L1 范数

# 训练 D
d_real = discriminator(real_imgs)   # 输入真实图片的重建
d_fake = discriminator(gen_imgs.detach()) # 输入生成图片的重建
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake

# 更新 k
diff = torch.mean(gamma * d_loss_real - d_loss_fake)  # 与期望的 gamma 差距有多少
k = k + lambda_k * diff.item()       # 以 lambda_k 的学习率进行更新
k = min(max(k, 0), 1)           # 限制 k 属于 [0, 1]

# 计算 收敛性测量
M = (d_loss_real + torch.abs(diff)).data[0]

参考链接:

github.com/eriklindern…

arxiv.org/abs/1703.10…


ONE MORE THING

咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)