【附Python源码】GAN网络实现图像生成
生成对抗网络(Generative Adversarial Networks, GAN)自2014年由Ian Goodfellow提出以来,已成为深度学习领域最具影响力的模型架构之一。本项目使用PyTorch实现一个经典的GAN模型,用于生成MNIST手写数字图像。
一、GAN基本原理
GAN的核心思想源于博弈论中的零和博弈,由两个神经网络组件构成:
生成器(Generator) :接收随机噪声作为输入,学习真实数据的分布特征,输出生成的假样本。其目标是欺骗判别器,使其无法区分生成样本与真实样本。
判别器(Discriminator) :接收真实样本或生成样本作为输入,输出一个概率值表示输入为真实样本的可能性。其目标是准确区分真实样本与生成样本。
训练过程中,生成器与判别器交替优化,形成对抗关系。当达到纳什均衡时,生成器能够生成以假乱真的样本。
二、项目架构设计
本项目采用模块化设计,主要包含以下组件:
2.1 数据预处理模块
MNIST数据集包含60000张训练图像和10000张测试图像,每张图像为28×28像素的灰度图。数据预处理流程如下:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
首先通过ToTensor将图像像素值从[0, 255]归一化至[0, 1],再通过Normalize将数据分布调整至[-1, 1]。该归一化策略与生成器输出层的Tanh激活函数相匹配,有助于稳定训练过程。
DataLoader的配置考虑了GPU训练场景:
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True
)
其中pin_memory=True将数据固定在页内存中,可加速CPU到GPU的数据传输;num_workers=4启用多进程数据加载,减少IO等待时间。
2.2 生成器网络结构
生成器的设计目标是将低维噪声向量映射至高维图像空间。本实现采用全连接层与转置卷积层相结合的架构:
class Generator(nn.Module):
def __init__(self, latent_dim, img_size, channels):
super(Generator, self).__init__()
self.init_size = img_size // 4
self.fc = nn.Sequential(
nn.Linear(latent_dim, 128 * self.init_size ** 2),
nn.BatchNorm1d(128 * self.init_size ** 2),
nn.LeakyReLU(0.2, inplace=True)
)
self.conv_blocks = nn.Sequential(
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
nn.Tanh()
)
网络首先通过全连接层将100维噪声向量映射至128×7×7的特征图,随后通过两层转置卷积逐步上采样:7×7→14×14→28×28。BatchNorm层的引入可缓解内部协变量偏移问题,LeakyReLU激活函数则避免了梯度稀疏性。
2.3 判别器网络结构
判别器采用与生成器对称的编码器结构:
class Discriminator(nn.Module):
def __init__(self, channels, img_size):
super(Discriminator, self).__init__()
self.conv_blocks = nn.Sequential(
nn.Conv2d(channels, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25)
)
self.fc = nn.Sequential(
nn.Linear(128 * (img_size // 4) ** 2, 1),
nn.Sigmoid()
)
判别器通过两层卷积将28×28的输入图像下采样至7×7的特征图,展平后通过全连接层输出分类概率。Dropout层的加入增强了模型的泛化能力,降低过拟合风险。
三、训练策略
3.1 损失函数设计
本项目采用二元交叉熵损失(BCELoss)作为优化目标:
判别器损失:
L_D = -[E[log D(x)] + E[log(1 - D(G(z)))]]
生成器损失:
L_G = -E[log D(G(z))]
其中x表示真实样本,z表示噪声向量,G(z)表示生成样本。
3.2 优化器配置
生成器与判别器分别使用独立的Adam优化器:
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
学习率设置为0.0002,beta参数采用(0.5, 0.999)组合。该配置在GAN训练中被广泛验证,可有效缓解训练不稳定问题。
3.3 多GPU训练支持
针对大规模训练场景,本项目实现了多GPU并行支持:
if torch.cuda.device_count() > 1:
generator = nn.DataParallel(generator)
discriminator = nn.DataParallel(discriminator)
DataParallel模块自动将批次数据分配至多个GPU进行前向与反向传播计算,显著提升了训练吞吐量。
3.4 断点续训机制
为应对长时间训练任务可能遭遇的意外中断,本项目实现了断点续训功能:
if continue_train and os.path.exists(continue_model_path):
checkpoint = torch.load(continue_model_path, map_location=device)
generator.load_state_dict(checkpoint)
通过保存模型状态字典并在训练开始时加载,可从任意检查点恢复训练进度。
四、训练过程分析
GAN的训练是一个动态博弈过程,需要关注以下指标:
判别器损失(D_loss) :反映判别器区分真假样本的能力。理想情况下应维持在0.5附近,过高或过低均表明训练失衡。
生成器损失(G_loss) :反映生成器欺骗判别器的能力。随着训练进行,该值应逐渐降低。
训练过程中采用tqdm库实现进度可视化:
pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
pbar.set_postfix({'D_loss': f'{d_loss.item():.4f}', 'G_loss': f'{g_loss.item():.4f}'})
五、生成效果评估
训练完成后,可通过以下方式评估生成器性能:
样本可视化:生成固定数量的样本并以网格形式展示,直观检验图像质量。
潜在空间插值:在噪声空间中选取两个点进行线性插值,观察生成结果的平滑过渡。若插值结果呈现连续变化,表明生成器学习到了良好的数据流形结构。
本项目的测试模块实现了上述功能,生成结果保存在samples目录下供进一步分析。
⚠️项目地址:github.com/anjuxi/GAN-…