什么是GAN?
-
GAN,全称 Generative Adversarial Network,即生成对抗网络。该网络模型由Ian J. Goodfellow在2014年首次提出,以下是该论文原文下载地址:Generative Adversarial Nets。
-
生成对抗网络(GAN)是一种通过框架内两个核心模块——生成模型(Generative Model)和判别模型(Discriminative Model)——相互博弈学习,从而产生高质量输出的深度学习模型。作为当前最具前景和活跃度的生成式模型之一,GAN 在样本数据生成、图像合成、图像修复、图像转换以及文本生成等多个领域展现出强大能力,标志着生成式人工智能(AIGC)的关键突破。
GAN 的核心思想是通过生成器和判别器的对抗训练,使生成器能够不断优化以生成逼真的数据,而判别器则不断提升鉴别真伪的能力。这种动态博弈机制使得 GAN 能够生成高度接近真实分布的图像或数据,成为现代生成式 AI 的重要基石之一。
生成对抗网络(GAN)的核心思想是通过**生成器(Generator)和判别器(Discriminator)**的对抗训练,使生成数据的分布逐步逼近真实数据的分布。在训练过程中,生成器从随机噪声中合成样本,并不断优化其生成能力,力求使生成的样本与真实数据尽可能相似,从而"欺骗"判别器。与此同时,判别器则通过对比生成样本和真实样本,持续提升自身的鉴别能力,以更精准地区分两者的差异。这种动态博弈机制推动双方不断优化,最终使生成器能够输出高度逼真的数据。
GAN的工作原理
-
核心构成
GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。
- 生成器:通过机器生成数据,目的是尽可能“骗过”判别器,生成的数据记做G(z);
- 判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的“假数据”。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。
经过这样的设计,G和D就构成了一个动态对抗的过程,随着多次训练之后,G生成的数据越来越接近真实数据,D判断数据真伪的水平也越来越高。最后在训练的后期,G所生成的数据足够欺骗D,对于D来讲,它则难以判断数据究竟是G生成还是真实数据,因此最后的D(G(z))=0.5。这样我们就得到了一个生成模型可以生成足够以假乱真的数据。
-
训练步骤
- 第一阶段:固定判别器D,训练生成器G。首先使用一个性能不错的判别器D,G通过噪声不断生成假数据,将其丢给D去判断。实验开始时,G生成数据能力还比较弱,很容易就被判别出来。但随着训练的继续,G的生成能力逐渐提升,最终骗过判别器D,这时候D判断是否为假数据的概率为0.5。
- 第二阶段:固定生成器G,训练判别器D。当D判断是否为假数据的概率为0.5,再训练G就没有意义了,此时我们需要训练D。训练D之前,我们先固定G,然后不断训练D。通过不断训练,D提高了自己的鉴别能力,又能够判断出假数据了。
- 不断重复第一阶段与第二阶段:通过不断的训练循环,生成器G和判别器D的能力都很强了,我们就能得到一个生成数据效果很好的生成器G。
GAN的数学原理
注意:该章主要是对GAN文献原文中所涉及到的部分数学原理做介绍,内容相对有难度,请读者按需阅读!
-
GAN中各种数据变量解释
GAN原文的应用是分别训练两个多层感知机来扮演生成器G和判别器D,首先为了训在真实数据上的真实数据分布,我们定义了一个噪声数据上的噪声数据分布,通常该分布可以使用均匀分布、高斯分布等,是实验者人为定义的分布。
接下来,我们定义一个多层感知机,其中是噪声数据,为生成器多层感知机的训练参数。再将上文提到的噪声数据分布作为生成器的输入,并其映射为一个新的数据分布,即生成样本分布,该分布不同于噪声数据分布,该分布可能十分复杂。接下来的训练过程就是将生成样本分布不断逼近真实数据分布。
那么以上的各种表达式就满足以下的数学关系:
接一下我们定义第二个多层感知机,其中是判别器的输入,它可能来源于生成器生成的假数据,也可能来自于真实数据,为判别器多层感知机的训练参数。那么判别器D的输出为一个标量即判别该是真数据还是假数据。
我们可以使用下面这个表格再次理解一下其中的各个变量。
变量 含义 噪声向量 噪声向量的先验分布(如高斯分布、均匀分布等) 真实数据的概率分布 生成器生成的隐式分布 生成器网络的训练参数 生成器网络,将映射为(GAN原文使用的是多层感知机) 判别器网络的训练参数 判别器网络,判别是否来自真实数据(GAN原文使用的是多层感知机) -
GAN的损失函数解析
训练网络得少不了解析损失函数,我们直接给出GAN原文中提到的损失函数,我们再对其进行解析。
损失函数如下:
这个公式看似很复杂,其实是可以理解为两个公式。
-
针对生成器G,损失函数可以理解为:
-
其中,为噪声,G(z)为生成器由噪声生成的假数据,D(G(z))为判别器判别由生成器送来数据的结果。
如果此时D(G(z)) = 0,则代表判别器成功判断出该数据是假数据,那么此时log(1-D(G(z)))就会等于0。如果此时D(G(z)) = 1,则代表判别器没能判断出该数据是假数据,那么此时log(1-D(G(z)))就会趋向于负无穷。所以我们训练生成器的目标就是尽量让判别器出错,这样该损失函数的值就能取得最小值。
注:在GAN原文中指出,早期训练log(1-D(G(z)))时,由于此时的生成器太弱,容易出现判别器赢得对抗,导致生成器无法进行训练优化的情况,在数学上的表现就是训练过程中梯度消失,所以我们在训练早期改用最大化log(D(G(z)))来训练生成器。
-
-
针对判别器D,损失函数可以理解为:
-
其中,x是真实数据,D(x)是判别器判断真实数据的结果。
上文中,我们已经解释加号后部分工作原理,即该部分越大,判别器越能判断出数据是否是假数据,所以该部分对于判别器来说应当取得最大值。接一下我们主要解释加号前部分的工作原理。
此时,若D(x)=1,则判别器成功判别出该真实数据为真实数据,那么log(D(x))就会等于0。若此时D(x)=0,则代表判别器将真实数据判断为假数据,那么log(D(x))就会趋向于负无穷。所以,我们为了训练判别器D,我们就需要让判别器尽量正确判别出数据是否为真数据,即要让该公式取得最大值。
-
-
-
GAN训练过程的图解
注:该图来源于GAN原文
-
图中元素解析
- 黑色虚线:真实数据分布
- 绿色实线:生成器所拟合的数据分布
- 蓝色虚线:判别器的输出概率
- 判别器最佳时,x为真数据时,D(x)=1,x为假数据,D(x)=0。生成器最佳时,D(x)=0.5即判别器只能乱猜数据是否为真。
- 上方水平线:数据 的分布空间
- 下方水平线:噪声的采样空间
- 箭头:生成器G将噪声z映射到数据空间x的过程,即将映射为的过程。
-
图中各个阶段解读
(a)初始阶段
- 绿色实线与黑色虚线差别很大,即生成分布与真实分布差异过大,生成数据质量比较低。
- 蓝色虚线在绿色实线低的位置高,在绿色实线高的位置低,代表判别器D能够初步区分出真实数据与生成数据。
- 总结该阶段:此时生成器还没有能力生成足够欺骗判别器的数据,判别器已经有了初步的判别能力。
(b)判别器优化
- 从(a)到(b)的主要差异是蓝色虚线的变化,蓝色虚线从(a)阶段的有高低起伏趋向于稳定。
- 判别器D趋向于最优解,即判别器在生成数据少的部分能够有效判断出为真实数据,在生成数据多的部分也能有效判断出假数据。
- 总结该阶段:此时生成器被固定依然没有能力生成足够欺骗判别器的数据,而判别器的判别能力趋向最优解。
(c )生成器优化
- 从(b)到(c)的主要差异是绿色实线的变化和下方箭头的变化,绿色实线开始向黑色虚线趋近,箭头也从指向右侧变为指向中部。
- 这两个变化的含义相同,由于生成器的能力不断优化,箭头的变化代表生成器正将噪声z映射到数据空间的变化,越来越接近真实数据分布,这就造成了绿色实线不断靠近黑色虚线,即生成数据的数据分布正逐渐趋近于真实数据分布。
- 总结该阶段:生成器不断优化生成数据的能力,生成数据不断接近真实数据。
(d)收敛阶段
-
从(c)到(d)的主要差异是绿色实线与黑色虚线重合,蓝色虚线变为一条无变化的直线,箭头更加趋近于中部。
-
这些变化的含义都表示此时生成器已经达到最优,箭头的变化代表生成器已经有能力将噪声z映射到数据空间,并且该分布与真实分布完全相同,这就造成绿色实线与黑色虚线完全重合,同时,蓝色虚线代表的判别器的输出概率公式为
由于与真实分布完全相同,使得D(x)恒等于0.5,所以蓝色虚线就变为了一条直线。
-
总结该阶段:此时生成器已经有了能力生成足够欺骗判别器的数据,判别器没有能力再判断数据的真伪,陷入只能瞎猜的境地。
-
针对图解常见问题解答
1. 该过程就是GAN训练的全过程吗?
- 实际上,该图是用训练过程中几个理想片段来表达GAN的训练过程,其中(a)是训练一开始阶段,(d)是训练达到收敛的阶段,而(b)和(c)在实际训练中需要经过多许多次迭代,才能达到(d),即真实训练中,(a)需要经过很多(b)和(c)阶段才能达到(d)。
2. 该图解表述为先训练D,而上文步骤中表述为先训练G,究竟是先训练哪一个?
- 由上一问我们得知,在一次迭代中,生成器和训练器都要进行一次参数更新优化,其中一个网络的性能提升都会带动另外一个网络的性能提升,所以在完整的一个训练过程中一次细微迭代中究竟是先训练G还是D并不会对结果造成太大的影响。
-
-
GAN的训练算法步骤
注:本章是对GAN原文所提及的算法做解释,可能与实际生成中算法有一定出入
以下是GAN原本中提及的算法伪代码:
for 训练迭代次数 do # 步骤1:优化判别器 D(k 次更新) for k steps do 1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z) 2. 从真实数据中采样批样本:{x^(1), ..., x^(m)} ∼ p_data(x) 3. 更新判别器参数 θ_d,通过梯度上升: end for # 步骤2:优化生成器 G(1 次更新) 1. 从噪声先验中采样批噪声:{z^(1), ..., z^(m)} ∼ p_z(z) 2. 更新生成器参数 θ_g,通过梯度下降: end for
可能第一次看不明白以上代码究竟是什么含义,接下来我们会做完整介绍。
-
第一层循环是重复迭代次数个循环,这个循环等同于上文中重复多次(b)和(c)的过程。
-
第二层循环是重复k次,k是一个超参数是由实验者人为指定的参数,该层循环等同上文中图解中的(b)过程,只是训练(b)时,我们需要重复k次。
-
判别器的优化过程:
-
首先,我们从噪声先验中采样批噪声:。
-
然后,我们再从真实数据中采样批样本:。
-
将这一批的噪声与数据同时送入到以下损失函数中并计算梯度:
判别器的损失函数构成:
所以,我们将真实数据送入到加号前一项,噪声数据送入到加号后一项,然后我们计算该批次梯度,梯度计算公式如下:
-
接下来,我们做参数更新,由于我们要求的最大值,所以此时应该是梯度上升:
- 其中,和是更新前后的参数,是学习率,为本次计算得到梯度。
-
重复上述过程k次。
注:k的选择需要保证判别器有一定的优化空间,又不至于优化太好,使得生成器的优化受限。
-
-
生成器的优化过程:
-
首先,我们也是从噪声先验中采样批噪声:。
-
再将这一批噪声送入到以下损失函数并计算梯度:
生成器的损失函数构成:
然后,我们计算该批次噪声的梯度:
-
接下来,做参数更新,由于我们此时要求最小值,所以应当使用梯度下降:
- 其中,和是更新前后的参数,是学习率,为本次计算得到梯度。
-
-
最后将以上判别器和生成器的优化过程重复迭代次数即可。
-
GAN代码实例演示——实现手写数字
数据集选择与加载:MNIST数据集
-
MNIST数据集是机器学习领域最经典的入门数据集之一,主要用于手写数字识别任务,该数据集的内容主要包括0到9的手写数字的灰度图片,每张图片大小为28x28像素。该数据集的数据量训练集有60,000张图片,测试集10,000张图片。本文只使用MNIST数据集的训练集部分。
数据集加载与显示代码部分(本文最后设计有全部代码)
注:其中有部分设计到超参数的设置,在一章会有说明
# 加载MNIST数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到[-1, 1] ]) dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 可视化部分 # 定义类别标签 class_names = ['0','1','2','3','4','5','6','7','8','9'] # 从训练集中随机取一个batch的图像 images, labels = next(iter(loader)) # 获取一个batch(64张图) # iter()转换为迭代器,next()获取下一个批次的数据 # images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量 # 显示图像函数 def imshow(img): img = img.numpy() img = np.squeeze(img) # 移除单通道维度 (1,28,28) -> (28,28) img = img * 0.5 + 0.5 # 反归一化到[0,1] plt.imshow(img, cmap='gray') plt.axis('off') # 画出一个4x8的网格(共32张图) plt.figure(figsize=(12, 6)) for i in range(32): # 显示前32张 plt.subplot(4, 8, i+1) imshow(images[i]) plt.title(class_names[labels[i].item()], fontsize=8) plt.tight_layout() plt.show()
-
数据集图片演示
超参数设置与网络设计
-
超参数设置一般放在代码的最前面,这一部分并非必需,也可以在后面的代码部分手动设置,这里只是习惯问题。
# 设置超参数 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备 print(device) lr = 0.0002 # 学习率 z_dim = 64 # 噪声维度 image_dim = 28 * 28 * 1 # MNIST图像维度 batch_size = 64 # 批量大小 epochs = 50 # 训练轮数
-
网络设计:
-
生成器G网络设计
生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Tan函数。
# 生成器网络设计 class Generator(nn.Module): def __init__(self, z_dim, img_dim): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(z_dim, 128), nn.LeakyReLU(0.1), nn.Linear(128, 256), nn.LeakyReLU(0.1), nn.Linear(256, img_dim), nn.Tanh() ) def forward(self, x): return self.model(x)
-
判别器D网络设计
生成器采用三层感知机,其中激活函数选用LeakyReLU函数,斜率设置为0.1,最后的激活函数选择Sigmoid函数。
# 判别器网络设计 class Discriminator(nn.Module): def __init__(self, img_dim): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(img_dim, 256), nn.LeakyReLU(0.1), nn.Linear(256, 128), nn.LeakyReLU(0.1), nn.Linear(128, 1), nn.Sigmoid() ) def forward(self, x): return self.model(x)
-
网络实例化与循环训练
# 循环轮次
for epoch in range(epochs):
# 提取数据
for i, (real_img, _) in enumerate(loader):
# 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
# 采样真实数据
batch_size = real_img.shape[0]
real_img = real_img.view(-1, image_dim).to(device)
# 进行判别得到损失函数值
disc_real = discriminator(real_img).flatten()
real_labels = torch.ones_like(disc_real).to(device)
loss_real = criterion(disc_real, real_labels)
# 采样噪声数据
noise = torch.randn(batch_size, z_dim).to(device)
fake_img = generator(noise)
# 进行判别得到损失函数值
disc_fake = discriminator(fake_img.detach()).flatten()
fake_labels = torch.zeros_like(disc_fake).to(device)
loss_fake = criterion(disc_fake, fake_labels)
# 将两者损失值求和除以二,以免其中一个损失值过大影响训练
loss_disc = (loss_real + loss_fake) / 2
# 更新参数
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
# 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
# 将噪声数据采样进行判别
output = discriminator(fake_img).flatten()
# 计算损失函数值
loss_gen = criterion(output, real_labels)
# 更新参数
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
打印结果与结果保存
# 生成结果保存文件夹
os.makedirs("generated_images", exist_ok=True)
# 打印数据并保存图像数据
if i == 0:
print(
f"Epoch [{epoch+1}/{epochs}] "
f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
)
with torch.no_grad():
noise = torch.randn(batch_size, z_dim).to(device)
fake = generator(noise).reshape(-1, 1, 28, 28)
img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title(f"Epoch {epoch+1}")
plt.savefig(f"generated_images/epoch{epoch+1}.png")
#plt.show()
plt.close()
全部代码
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # 允许重复加载OpenMP库
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 设置超参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练设备
print(device)
lr = 0.0002 # 学习率
z_dim = 64 # 噪声维度
image_dim = 28 * 28 * 1 # MNIST图像维度
batch_size = 64 # 批量大小
epochs = 50 # 训练轮数
os.makedirs("generated_images", exist_ok=True)
# 加载MNIST数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到[-1, 1]
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 可视化部分
# 定义类别标签
class_names = ['0','1','2','3','4','5','6','7','8','9']
# 从训练集中随机取一个batch的图像
images, labels = next(iter(loader)) # 获取一个batch(64张图)
# iter()转换为迭代器,next()获取下一个批次的数据
# images为一个形状为[64, 1, 28, 28]的张量 labels为[64]的张量
# 显示图像函数
def imshow(img):
img = img.numpy()
img = np.squeeze(img) # 移除单通道维度 (1,28,28) -> (28,28)
img = img * 0.5 + 0.5 # 反归一化到[0,1]
plt.imshow(img, cmap='gray')
plt.axis('off')
# 画出一个4x8的网格(共32张图)
plt.figure(figsize=(12, 6))
for i in range(32): # 显示前32张
plt.subplot(4, 8, i+1)
imshow(images[i])
plt.title(class_names[labels[i].item()], fontsize=8)
plt.tight_layout()
plt.show()
# 生成器网络设计
class Generator(nn.Module):
def __init__(self, z_dim, img_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(z_dim, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, img_dim),
nn.Tanh()
)
def forward(self, x):
return self.model(x)
# 判别器网络设计
class Discriminator(nn.Module):
def __init__(self, img_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_dim, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 网络实例化
generator = Generator(z_dim, image_dim).to(device)
discriminator = Discriminator(image_dim).to(device)
# 损失函数选择交叉熵损失函数
criterion = nn.BCELoss()
# 优化器选择Adam优化器
opt_gen = torch.optim.Adam(generator.parameters(), lr=lr)
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=lr)
# 循环轮次
for epoch in range(epochs):
# 提取数据
for i, (real_img, _) in enumerate(loader):
# 1. 训练判别器 log(D(real)) + log(1 - D(G(z)))
# 采样真实数据
batch_size = real_img.shape[0]
real_img = real_img.view(-1, image_dim).to(device)
# 进行判别得到损失函数值
disc_real = discriminator(real_img).flatten()
real_labels = torch.ones_like(disc_real).to(device)
loss_real = criterion(disc_real, real_labels)
# 采样噪声数据
noise = torch.randn(batch_size, z_dim).to(device)
fake_img = generator(noise)
# 进行判别得到损失函数值
disc_fake = discriminator(fake_img.detach()).flatten()
fake_labels = torch.zeros_like(disc_fake).to(device)
loss_fake = criterion(disc_fake, fake_labels)
# 将两者损失值求和除以二,以免其中一个损失值过大影响训练
loss_disc = (loss_real + loss_fake) / 2
# 更新参数
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
# 训练生成器 最小化 log(1 - D(G(z))) → 最大化 log(D(G(z)))
# 将噪声数据采样进行判别
output = discriminator(fake_img).flatten()
# 计算损失函数值
loss_gen = criterion(output, real_labels)
# 更新参数
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
# 打印数据并保存图像数据
if i == 0:
print(
f"Epoch [{epoch+1}/{epochs}] "
f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
)
with torch.no_grad():
noise = torch.randn(batch_size, z_dim).to(device)
fake = generator(noise).reshape(-1, 1, 28, 28)
img_grid = torchvision.utils.make_grid(fake, nrow=4, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title(f"Epoch {epoch+1}")
plt.savefig(f"generated_images/epoch{epoch+1}.png")
#plt.show()
plt.close()
结果展示
- 显然该网络并没有训练到完全拟合,还可以继续增加训练的轮数使得网络训练更加趋近于拟合。