【技术专题】PyTorch2 深度学习 - 生成式对抗网络(GAN)

0 阅读6分钟

大家好,我是锋哥。最近连载更新《PyTorch2 深度学习》技术专题。

image.png 本课程主要讲解基于PyTorch2的深度学习核心知识,主要讲解包括PyTorch2框架入门知识,环境搭建,张量,自动微分,数据加载与预处理,模型训练与优化,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。。 同时也配套视频教程 《PyTorch 2 Python深度学习 视频教程》

生成对抗网络(GAN)简介

GAN简介

生成对抗网络是一种无监督深度学习模型,其核心思想是通过让两个神经网络相互“竞争”或“对抗”来学习。这两个网络分别是:

  1. 生成器: 它的目标是学习真实数据的分布,并生成足以“以假乱真”的虚假数据。它接收一个随机噪声向量作为输入,输出一个伪造的数据样本。
  2. 判别器: 它的目标是成为一个“鉴定专家”,能够正确区分输入的数据是来自真实数据集还是生成器生成的假数据。它接收一个数据样本,输出一个标量,表示该样本为真的概率。

对抗过程可以类比为:

  • 生成器 像一个伪造者,努力制作更逼真的假画。
  • 判别器 像一个鉴定专家,努力识别出画作的真伪。
  • 两者不断博弈,最终伪造者(生成器)的技能变得如此高超,以至于它生成的画作让鉴定专家(判别器)也无法分辨真伪(输出概率接近0.5)

GAN在多个领域具有广泛的应用,特别是在生成式任务中,如:

  • 图像生成:生成高度真实的图像,如人脸图像(例如DeepFake)和艺术风格转换。
  • 图像超分辨率:提高图像分辨率,生成高清晰度图像。
  • 图像修复:填补图像中的缺失部分,修复损坏的图像。
  • 数据增强:生成更多的训练样本,尤其在数据稀缺时。
  • 文本到图像生成:根据文本描述生成对应的图像(例如“一个红色的苹果”生成一张苹果的图像)

GAN基本架构图

image.png

GAN损失函数的定义

image.png

生成对抗网络(GAN)示例

我们以生成手写数字数据集为示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
​
# 创建保存图像的文件夹
os.makedirs("gan_images", exist_ok=True)
os.makedirs("models", exist_ok=True)
​
# 超参数设置
batch_size = 128
latent_dim = 100  # 噪声向量维度
hidden_dim = 256
image_dim = 28 * 28  # MNIST图像尺寸
num_epochs = 100
learning_rate = 0.0002
sample_interval = 200  # 每隔多少批次保存一次样本# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 将像素值从[0,1]归一化到[-1,1]
])
​
# 加载MNIST数据集[citation:2][citation:3]
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)
​
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)
​
​
# 定义生成器[citation:2][citation:5]
class Generator(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),  # inplace=True参数表示直接在原tensor上进行修改,不创建新的tensor,可以节省内存空间
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim * 4, output_dim),
            nn.Tanh()  # 输出范围[-1, 1],与预处理匹配
        )
​
    # 生成器类的前向传播函数。功能是接收随机噪声z作为输入,通过内部模型生成图像数据,并将输出重塑为28x28的单通道图像格式。
    def forward(self, z):
        img = self.model(z)
        """
        将输入的图像张量重新调整形状为(batch_size, 1, 28, 28)的格式。具体来说:
        img.size(0)保持批次大小不变
        将图像转换为单通道(1)
        调整图像尺寸为28×28像素
        使用view()方法进行张量重塑,不改变数据内容只改变维度排列
        """
        img = img.view(img.size(0), 1, 28, 28)
        return img
​
​
# 定义判别器[citation:2][citation:5]
class Discriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),  # 0.2 是负轴部分的斜率系数,当输入为负值时,输出为输入值乘以0.2
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # 输出0-1的概率值
        )
​
    def forward(self, img):
        """
            这段代码的功能是将输入的图像数据进行展平操作。
            具体来说:
            img.view() 是PyTorch中的张量重塑方法
            img.size(0) 获取批次大小(batch size)
            -1 表示自动计算该维度的大小,将图像的所有像素展平成一维向量
            结果是将形状为 [batch_size, channels, height, width] 的图像张量转换为 [batch_size, channels×height×width] 的二维张量
            这样做的目的是将多维的图像数据转换为全连接层可以处理的一维向量格式。
        """
        flattened = img.view(img.size(0), -1)
        validity = self.model(flattened)
        return validity
​
​
# 初始化模型
generator = Generator(latent_dim, hidden_dim, image_dim)
discriminator = Discriminator(image_dim, hidden_dim)
​
# 定义损失函数和优化器[citation:2][citation:3]
"""
这段代码创建了一个Adam优化器用于训练生成器。
功能解释:
optim.Adam() - 创建Adam优化算法实例
generator.parameters() - 获取生成器模型的所有可训练参数
lr=learning_rate - 设置学习率为指定值
betas=(0.5, 0.999) - 设置Adam算法的两个动量参数,分别控制一阶和二阶矩估计的指数衰减率
该优化器将用于更新生成器的权重参数以最小化损失函数。
"""
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
​
# 用于可视化的固定噪声
fixed_noise = torch.randn(64, latent_dim)
​
# 训练统计
d_losses = []
g_losses = []
​
print("开始训练GAN...")
​
# 训练循环[citation:2][citation:3][citation:5]
for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(train_loader):
        batch_size = real_imgs.size(0)
​
        # 创建标签
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
​
        # ---------------------
        #  训练判别器
        # ---------------------
        optimizer_D.zero_grad()
​
        # 计算真实图像的损失
        real_output = discriminator(real_imgs)  # 判别器对真实图像进行预测
        d_loss_real = adversarial_loss(real_output, real_labels)  # 计算真实图像的损失
​
        # 生成假图像
        z = torch.randn(batch_size, latent_dim)  # 生成随机噪声向量
        fake_imgs = generator(z)  # 生成假图像
​
        # 计算假图像的损失
        fake_output = discriminator(fake_imgs.detach())  # 判别器对生成的假图像进行预测
        d_loss_fake = adversarial_loss(fake_output, fake_labels)  # 计算假图像的损失
​
        # 总判别器损失
        d_loss = (d_loss_real + d_loss_fake) / 2  # 计算判别器的损失
        d_loss.backward()  # 反向传播
        optimizer_D.step()  # 更新判别器的权重参数
​
        # ---------------------
        #  训练生成器
        # ---------------------
        optimizer_G.zero_grad()  # 清空生成器的梯度
​
        # 生成器希望假图像被判别为真
        output = discriminator(fake_imgs)  # 判别器对生成的假图像进行预测
        g_loss = adversarial_loss(output, real_labels)  # 计算生成器的损失
​
        g_loss.backward()  # 反向传播
        optimizer_G.step()  # 更新生成器的权重参数
​
        # 记录损失
        d_losses.append(d_loss.item())  # 记录判别器的损失
        g_losses.append(g_loss.item())  # 记录生成器的损失
​
        # 定期保存生成的图像[citation:3]
        if i % sample_interval == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(train_loader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
​
            # 保存生成图像示例
            with torch.no_grad():  # 禁用梯度计算
                fake_imgs_sample = generator(fixed_noise).detach().cpu()  # 生成假图像
​
                # 保存为图像网格
                torchvision.utils.save_image(
                    fake_imgs_sample,
                    f"gan_images/epoch_{epoch}_batch_{i}.png",
                    nrow=8,
                    normalize=True
                )

运行代码后,相对目录下生成图片:

image.png

我们打开前面的,比较差。

image.png

打开后面的一些,就很不错了。

image.png