【附Python源码】GAN网络实现图像生成

6 阅读5分钟

【附Python源码】GAN网络实现图像生成

生成对抗网络(Generative Adversarial Networks, GAN)自2014年由Ian Goodfellow提出以来,已成为深度学习领域最具影响力的模型架构之一。本项目使用PyTorch实现一个经典的GAN模型,用于生成MNIST手写数字图像。

一、GAN基本原理

GAN的核心思想源于博弈论中的零和博弈,由两个神经网络组件构成:

生成器(Generator) :接收随机噪声作为输入,学习真实数据的分布特征,输出生成的假样本。其目标是欺骗判别器,使其无法区分生成样本与真实样本。

判别器(Discriminator) :接收真实样本或生成样本作为输入,输出一个概率值表示输入为真实样本的可能性。其目标是准确区分真实样本与生成样本。

训练过程中,生成器与判别器交替优化,形成对抗关系。当达到纳什均衡时,生成器能够生成以假乱真的样本。

二、项目架构设计

本项目采用模块化设计,主要包含以下组件:

2.1 数据预处理模块

MNIST数据集包含60000张训练图像和10000张测试图像,每张图像为28×28像素的灰度图。数据预处理流程如下:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

首先通过ToTensor将图像像素值从[0, 255]归一化至[0, 1],再通过Normalize将数据分布调整至[-1, 1]。该归一化策略与生成器输出层的Tanh激活函数相匹配,有助于稳定训练过程。

DataLoader的配置考虑了GPU训练场景:

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

其中pin_memory=True将数据固定在页内存中,可加速CPU到GPU的数据传输;num_workers=4启用多进程数据加载,减少IO等待时间。

2.2 生成器网络结构

生成器的设计目标是将低维噪声向量映射至高维图像空间。本实现采用全连接层与转置卷积层相结合的架构:

class Generator(nn.Module):
    def __init__(self, latent_dim, img_size, channels):
        super(Generator, self).__init__()
        self.init_size = img_size // 4
        
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2),
            nn.BatchNorm1d(128 * self.init_size ** 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        self.conv_blocks = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

网络首先通过全连接层将100维噪声向量映射至128×7×7的特征图,随后通过两层转置卷积逐步上采样:7×7→14×14→28×28。BatchNorm层的引入可缓解内部协变量偏移问题,LeakyReLU激活函数则避免了梯度稀疏性。

2.3 判别器网络结构

判别器采用与生成器对称的编码器结构:

class Discriminator(nn.Module):
    def __init__(self, channels, img_size):
        super(Discriminator, self).__init__()
        
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * (img_size // 4) ** 2, 1),
            nn.Sigmoid()
        )

判别器通过两层卷积将28×28的输入图像下采样至7×7的特征图,展平后通过全连接层输出分类概率。Dropout层的加入增强了模型的泛化能力,降低过拟合风险。

三、训练策略

3.1 损失函数设计

本项目采用二元交叉熵损失(BCELoss)作为优化目标:

判别器损失

L_D = -[E[log D(x)] + E[log(1 - D(G(z)))]]

生成器损失

L_G = -E[log D(G(z))]

其中x表示真实样本,z表示噪声向量,G(z)表示生成样本。

3.2 优化器配置

生成器与判别器分别使用独立的Adam优化器:

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

学习率设置为0.0002,beta参数采用(0.5, 0.999)组合。该配置在GAN训练中被广泛验证,可有效缓解训练不稳定问题。

3.3 多GPU训练支持

针对大规模训练场景,本项目实现了多GPU并行支持:

if torch.cuda.device_count() > 1:
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)

DataParallel模块自动将批次数据分配至多个GPU进行前向与反向传播计算,显著提升了训练吞吐量。

3.4 断点续训机制

为应对长时间训练任务可能遭遇的意外中断,本项目实现了断点续训功能:

if continue_train and os.path.exists(continue_model_path):
    checkpoint = torch.load(continue_model_path, map_location=device)
    generator.load_state_dict(checkpoint)

通过保存模型状态字典并在训练开始时加载,可从任意检查点恢复训练进度。

四、训练过程分析

GAN的训练是一个动态博弈过程,需要关注以下指标:

判别器损失(D_loss) :反映判别器区分真假样本的能力。理想情况下应维持在0.5附近,过高或过低均表明训练失衡。

生成器损失(G_loss) :反映生成器欺骗判别器的能力。随着训练进行,该值应逐渐降低。

训练过程中采用tqdm库实现进度可视化:

pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
pbar.set_postfix({'D_loss': f'{d_loss.item():.4f}', 'G_loss': f'{g_loss.item():.4f}'})

五、生成效果评估

训练完成后,可通过以下方式评估生成器性能:

样本可视化:生成固定数量的样本并以网格形式展示,直观检验图像质量。

潜在空间插值:在噪声空间中选取两个点进行线性插值,观察生成结果的平滑过渡。若插值结果呈现连续变化,表明生成器学习到了良好的数据流形结构。

本项目的测试模块实现了上述功能,生成结果保存在samples目录下供进一步分析。

⚠️项目地址:github.com/anjuxi/GAN-…