简介
BEGAN 提出了一种新的均衡训练方法,该方法与从 Wasserstein 距离导出的损失相匹配,用于训练基于自动编码器的 GAN。该方法在训练期间平衡生成器和鉴别器。此外,它提供了一种新的近似收敛度量、快速稳定训练和高视觉质量的方法。推导出一种控制图像多样性和视觉质量之间权衡的方法。
典型的 GAN 试图直接匹配数据分布,但我们的方法旨在使用从 Wasserstein 距离导出的损失来匹配自动编码器损失分布。这是使用一个典型的GAN目标完成的,并添加了一个平衡项来平衡鉴别器和生成器。与典型的 GAN 技术相比,我们的方法具有更简单的训练过程和使用更简单的神经网络结构。
判别器
判别器的结构是自编码器结构,输入是真实图片或生成图片,然后通过基于 CNN 的 Encoder,将图像编码为 Embedding。接着将 Embedding 通过基于 转置 CNN 的 Decoder 得到重建后的图片。
生成器
生成器的结构和判别器的 Decoder 一致,输入噪声向量,输出生成的图片。
损失函数为:
平衡项
在实践中,保持生成器和判别器损失之间的平衡至关重要。我们认为它们在以下情况下处于平衡状态:
最终的 Loss
收敛性测量
该度量可用于确定网络何时达到其最终状态或模型是否已崩溃。
Code
# 超参数
gamma = 0.75
lambda_k = 0.001
k = 0.0
# 生成噪声
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# 训练 G
gen_imgs = generator(z)
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs)) # 用了 L1 范数
# 训练 D
d_real = discriminator(real_imgs) # 输入真实图片的重建
d_fake = discriminator(gen_imgs.detach()) # 输入生成图片的重建
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake
# 更新 k
diff = torch.mean(gamma * d_loss_real - d_loss_fake) # 与期望的 gamma 差距有多少
k = k + lambda_k * diff.item() # 以 lambda_k 的学习率进行更新
k = min(max(k, 0), 1) # 限制 k 属于 [0, 1]
# 计算 收敛性测量
M = (d_loss_real + torch.abs(diff)).data[0]
参考链接:
ONE MORE THING
咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)