⌈ 传知代码 ⌋ 生成对抗网络GAN详解与实现

174 阅读5分钟

前情提要

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


💡本章重点

  • 生成对抗网络GAN详解与实现

🍞一. 概述

生成对抗网络(Generative Adversarial Networks, GAN)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN的核心思想是通过两个神经网络的对抗训练来生成逼真的数据。它包含两个主要部分:生成器(Generator)和判别器(Discriminator)。该模型在扩散模型被广泛应用之前一直是图像生成领域非常重要的一个模型,现如今即使再扩散模型强大的冲击下,该模型在如妆容迁移等领域依然有着非常广泛的应用。本文将详细介绍该模型并在MNIST数据集上进行实现。

  • Goodfellow I J, Pouget-Abadie J, Mirza M, et al. Generative Adversarial Networks[J]. arXiv, 2014

🍞二. 演示效果

在这里插入图片描述


🍞三.GAN详解

GAN的核心思想是通过两个神经网络的对抗训练来生成逼真的数据。它包含两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器的任务是从随机噪声中生成假数据。这些假数据尽可能逼真,以至于能骗过判别器。

生成器通过一个神经网络,将输入的随机噪声转化为与真实数据分布相似的输出。)判别器的任务是区分真实数据和生成器生成的假数据。

它通过一个神经网络来学习,并输出一个概率值,表示输入数据是来自真实数据分布的概率。两者相互博弈最终达到纳什均衡的状态从而使得生成器能够生成逼真的图片。

在这里插入图片描述

如何确定整个模型的损失函数呢?对于判别器D而言,输出的结果是对应图片为真实的概率,它需要尽量把真实的图片认为是真实的,把假的图片认为是假的。也就是说判别器对于真实图片的的结果要趋于1,对于假的图片的结果要趋于0,这种趋向是概率的趋向,也就是让两个概率的距离为0,设这个概率的距离为 d(p,q),那么对于判别器而言它的损失函数为:

在这里插入图片描述

同理对于生成器来说,目标是让判别器判断错,也就是让判别器的结果趋向于1,因此:

在这里插入图片描述 其中 z 是一个随机噪声输入。现在让我们来考虑距离函数 d 的形式,一般使用KL散度作为距离函数: 在这里插入图片描述

如果令 p 是一个常分布,也就是跟我们的模型没有关系,那么 H(p)会是一个常数,最后的距离函数为:

在这里插入图片描述

因此判别器的损失函数为:

在这里插入图片描述

同理生成器的损失函数为:

在这里插入图片描述


🍞四.实现

现在来一步步实现GAN,首先我们选择的数据集是MNIST手写字数据库,首先我们导入该数据库并选择一部分数据来展示:

train_dataset = datasets.MNIST(root='./data',
                               train = True,
                               download = True,
                               transform = transforms.Compose(
                                            [transforms.ToTensor(), 
                                             transforms.Normalize((0.5,), (0.5,))]
                                           )
                              )

dataloader = DataLoader(train_dataset,
                        batch_size = 128, 
                        shuffle = True
                       )

在这里插入图片描述 现在我们来定义生成器和判别器,这是该模型最关键的部分,我们要尽量让这两个器的复杂度或者能力一致而不至于一方很快就超过了另一方,在这个实现中,我们采用比较简单的全连接网络即可

# Generator
nn.Sequential(
            nn.Linear(self.latent_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh()
        )

# Discriminator

nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

然后便是训练过程,我们只要根据上面的损失函数一步一步写即可,在每一个batch中我们先对判别器进行优化,再对生成器进行优化

discriminator_optimizer.zero_grad()
true = discriminator(data)
    
fake = generator(z)
false = discriminator(fake.detach())
    
loss_d = loss_fn(true, torch.ones_like(true)-0.1 * torch.rand_like(true)) \
             + loss_fn(false, torch.zeros_like(false) + 0.1 * torch.rand_like(true))
loss_d.backward()
discriminator_optimizer.step()
    
    # generator
    
generator_optimizer.zero_grad()
false = discriminator(fake)
loss_g = loss_fn(false, torch.ones_like(false)-0.2 * torch.rand_like(true))
loss_g.backward()
generator_optimizer.step()

在这里我们使用了一个训练小技巧,即给了判别器一个松弛度,对于一个图片的真假我不需要你100%的判断正确,只要在90%以上即可,这种软距离的形式可以让判别器在训练过程中不会变的非常严厉从而直接把生成器快速打败。

遍历所有的数据进行训练即可。

for epoch in range(200):
    d_a = 0.
    g_a = 0.
    for step, (img, _) in enumerate(dataloader):
        
        d, g = train(D, G, do, go, loss_fn, img, device)

        d_a += d
        g_a += g

    print(f'epoch {epoch} dloss = {d_a * 128 / 60000} gloss = {g_a * 128/60000} generaging...')
    with torch.no_grad():
        show(G.sample(10), 1, 10)

🍞五.训练过程

在这里插入图片描述


🫓总结

综上,我们基本了解了“一项全新的技术啦” :lollipop: ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读:satisfied:

后续还会继续更新:heartbeat:,欢迎持续关注:pushpin:哟~

:dizzy:如果有错误❌,欢迎指正呀:dizzy:

:sparkles:如果觉得收获满满,可以点点赞👍支持一下哟~:sparkles:

【传知科技 -- 了解更多新知识】