开发也能看懂的大模型:GAN

1,839 阅读7分钟

生成对抗网络(GAN,Generative Adversarial Networks)是由 Ian Goodfellow 等人在 2014 年提出的一种深度学习模型。它采用对抗思想,通过两个网络的博弈,生成逼真的数据,广泛应用于图像生成、数据增强、文本生成等领域。

核心思想

GAN 的核心思想是通过两个网络的对抗学习,达到生成数据与真实数据难以区分的目标。

  1. 生成器(Generator)

    • 作用:根据随机噪声生成数据样本。
    • 目标:生成的数据尽可能“欺骗”判别器,使其认为这些是真实数据。
    • 输入:随机噪声(通常是正态分布或均匀分布)。
    • 输出:与真实数据分布相似的样本。
  2. 判别器(Discriminator)

    • 作用:判断输入的数据是真实数据还是生成数据。
    • 目标:最大化区分真假数据的能力。
    • 输入:生成器生成的数据或真实数据。
    • 输出:真假概率(通常是一个 0 到 1 的数值)。

工作流程

GAN 的训练过程可以看作一个“博弈”过程:

  1. 生成器生成假样本试图欺骗判别器。
  2. 判别器学习识别真假数据,提高判断能力。
  3. 两个网络交替训练,最终达到纳什均衡:生成器生成的数据分布与真实数据分布非常接近,判别器无法区分真假数据(判别准确率约为 50%)。

image.png

特点与优势

  1. 优点

    • 数据生成效果好,尤其是高质量图像。
    • 可以直接生成任意分布的数据。
    • 无需明确的概率模型。
  2. 挑战

    • 训练不稳定:生成器和判别器的训练不平衡会导致模式崩塌(mode collapse)。
    • 优化难度高:需要精心设计网络结构和超参数。
    • 计算资源需求大:生成高分辨率数据时,训练时间较长。

应用场景

  1. 图像生成

    • 生成高分辨率的逼真图片。
    • 人脸生成(如 DeepFake)。
    • 图像超分辨率(SRGAN)。
  2. 数据增强

    • 用于医学影像等领域,生成缺失的数据样本。
  3. 风格转换

    • 如 CycleGAN,用于实现不同风格之间的图像转换(如将照片转为画作风格)。
  4. 文本生成与翻译

    • GAN 的变体(如 SeqGAN)可用于生成文本和语言模型。
  5. 视频生成

    • 生成逼真的动画或视频内容。

进阶与变体

  1. DCGAN(Deep Convolutional GAN)

    • 使用卷积网络改进 GAN,适合处理图像数据。
  2. WGAN(Wasserstein GAN)

    • 改善了 GAN 的训练稳定性,解决了模式崩塌问题。
  3. CycleGAN

    • 实现无监督图像风格转换。
  4. StyleGAN

    • 高质量图像生成的最前沿方法,用于生成超逼真的人脸图片。
  5. BigGAN

    • 针对大规模数据集的 GAN 方法,生成高分辨率的图片。

案例:生成 MNIST 手写数字

使用生成对抗网络 (GAN) 生成 MNIST 手写数字是一种经典的案例,能够很好地展示 GAN 的核心机制和实现方法。以下是详细的解释与分解。


1. 数据集简介:MNIST

image.png


2. GAN 的组成

GAN 由两个部分组成,分别是生成器(Generator)和判别器(Discriminator)。它们通过相互竞争的方式训练,最终生成逼真的手写数字。

image.png


3. GAN 的工作流程

  1. 初始化

    • 初始化生成器和判别器的权重。
    • 定义随机噪声维度(如 100)。
    • 加载 MNIST 数据集,并归一化到 [0,1] 范围。
  2. 训练过程: GAN 的训练分为两步:

    • 训练判别器

      • 从真实数据集中随机采样图片,标记为 1。
      • 使用生成器生成假图片,标记为 0。
      • 优化判别器,使其能够正确分类真假图片。
    • 训练生成器

      • 使用随机噪声生成图片,并标记为“真实”(即希望判别器输出 1)。
      • 优化生成器,使生成的图片更像真实图片,从而“欺骗”判别器。
  3. 迭代训练

    • 判别器与生成器交替训练。
    • 随着训练进行,生成器逐渐生成更逼真的图片,判别器逐渐难以区分真假。

4. 代码解读

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

(1) 加载 MNIST 数据集

(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0  # 归一化
  • MNIST 数据集包含 28×28 灰度图像。
  • 数据被归一化到 [0,1],以便生成器输出的数据与真实数据范围一致。

(2) 生成器模型

def build_generator(latent_dim):
    model = Sequential([
        Dense(128, activation=LeakyReLU(0.2), input_dim=latent_dim),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28))
    ])
    return model
  • 输入:随机噪声(维度为 latent_dim,例如 100)。

  • 输出:经过全连接层映射为28×28 的图片。

  • 激活函数

    • LeakyReLU:增加非线性能力。
    • Sigmoid:将输出限制在 [0,1],适合生成灰度图片。

(3) 判别器模型

def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(128, activation=LeakyReLU(0.2)),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
  • 输入:真实图片或生成的假图片。
  • 输出:真假概率(0-1)。
  • 优化器:Adam,适合 GAN 训练。
  • 损失函数:二元交叉熵(binary_crossentropy),用来衡量真假图片的分类错误率。

(4) GAN 模型

def build_gan(generator, discriminator):
    discriminator.trainable = False  # 固定判别器权重
    model = Sequential([generator, discriminator])
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model
  • 将生成器和判别器组合起来,形成完整的 GAN。
  • 在训练生成器时,冻结判别器的参数。

(5) 训练过程

for epoch in range(epochs):
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise)

    d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))

    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")
  • 随机选取真实图片作为训练样本,生成与之匹配的假图片。
  • 交替训练判别器和生成器。
  • 每隔 1000 次打印损失信息,用于评估模型的训练情况。
    • D Loss Real 表示判别器对真实数据的损失。
    • D Loss Fake 表示判别器对生成数据(假数据)的损失。
    • G Loss 表示生成器的损失。

(6) 可视化生成结果

noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise)
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()
  • 随机生成 10 张图片,并展示。
  • 生成器生成的图片应该与真实手写数字高度相似。

5. 训练与结果分析

image.png

image.png

  1. 损失变化

    • 判别器的损失应逐渐减小,表明其能够更好地区分真假数据。
    • 生成器的损失应逐渐减小,表明其生成的图片更接近真实数据。
  2. 生成效果

    • 训练初期:生成器生成的图片模糊、不清晰。
    • 训练后期:生成的图片应与 MNIST 手写数字相似,质量明显提高。

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

# 生成器模型
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, activation=LeakyReLU(0.2), input_dim=latent_dim),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28))
    ])
    return model

# 判别器模型
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(128, activation=LeakyReLU(0.2)),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# GAN 模型
def build_gan(generator, discriminator):
    discriminator.trainable = False
    model = Sequential([generator, discriminator])
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model

# 加载数据
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0

# 参数
latent_dim = 100
batch_size = 128
epochs = 10000
generator = build_generator(latent_dim)
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)

# 训练
for epoch in range(epochs):
    # 随机选择真实样本
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # 生成假样本
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise)

    # 训练判别器
    d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))

    # 训练生成器
    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    # 打印损失
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")

# 生成图像
noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise)
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()