使用 PyTorch 学习生成式人工智能——变分自编码器(VAE)图像生成

255 阅读40分钟

本章内容涵盖

  • 自编码器(Autoencoder)与变分自编码器(Variational Autoencoder, VAE)的区别
  • 构建与训练自编码器以重建手写数字
  • 构建与训练变分自编码器以生成真人脸部图像
  • 使用训练好的变分自编码器进行编码运算与编码插值

到目前为止,你已经学习了如何利用生成对抗网络(GANs)生成各种形状、数字和图像。本章将教你使用另一种生成模型——变分自编码器(VAE)来生成图像。你还将学习通过编码运算和编码插值对VAE的实际应用。

为了理解VAE的工作原理,我们首先需要了解自编码器(AE)。自编码器由两个部分组成:编码器和解码器。编码器将数据压缩到低维空间中的抽象表示(潜在空间),而解码器则将编码后的信息解压并重构数据。自编码器的主要目标是学习输入数据的压缩表示,重点在于最小化重构误差,即原始输入与重构图像之间的差异(如第6章中计算循环一致性损失时的像素级差异)。编码器-解码器结构是各种生成模型的基石,包括你将在本书后半部分详细学习的Transformer。例如,在第9章,你将构建用于机器语言翻译的Transformer:编码器将英文短语转成抽象表示,解码器则基于该表示生成法语翻译。类似地,文本生成图像的Transformer(如DALL-E 2和Imagen)也采用自编码器架构,先将图像编码成低维概率分布,再从该分布解码生成图像。当然,不同模型中编码器和解码器的具体定义会有所不同。

本章的第一个项目是从零开始构建和训练自编码器,用于生成手写数字。你将使用6万个灰度手写数字图像(0到9),每张图像大小为28×28=784像素作为训练数据。自编码器的编码器将每幅图压缩为仅含20个值的确定性向量,解码器则重构图像,目标是最小化原图与重构图像在像素级的平均绝对误差。最终训练得到的自编码器能够生成几乎与训练集相同的手写数字。

虽然自编码器在复制输入数据方面表现良好,但它们往往难以生成训练集中不存在的新样本。更重要的是,自编码器不善于输入数据的插值,即无法有效生成两个输入数据点之间的中间表示。这就引出了变分自编码器(VAE)。VAE与AE有两个关键区别。首先,AE将每个输入编码成潜在空间的一个具体点,而VAE将其编码成潜在空间中的一个概率分布。其次,AE只关注最小化重构误差,而VAE则学习潜变量概率分布的参数,最小化包含重构损失和正则化项——KL散度的损失函数。

KL散度促使潜在空间近似某个分布(本例中为正态分布),确保潜变量不仅仅记忆训练数据,而是捕捉其潜在分布。这样,潜在空间结构良好,相似数据点映射得更接近,使得空间连续且可解释。结果,我们可以操作编码实现新的效果,这也使得编码运算和输入插值在VAE中成为可能。

本章的第二个项目是从头构建并训练VAE,生成真人脸部图像。训练数据为你在第5章下载的戴眼镜图像。VAE的编码器将尺寸为3×256×256=196,608像素的图像压缩为100个遵循正态分布的概率向量,解码器基于此概率向量重构图像。训练好的VAE不仅能复现训练集中的人脸,还能生成新的脸部图像。

你还将学习如何在VAE中进行编码运算和输入插值。通过操作不同输入的编码表示(潜向量),在解码时实现特定效果(例如图像中是否带有某些特征)。潜向量能控制解码图像的不同特征,如性别、是否戴眼镜等。举例来说,假设你分别获得带眼镜男性(z1)、带眼镜女性(z2)和不带眼镜女性(z3)的潜向量。然后计算新向量 z4 = z1 – z2 + z3。由于z1和z2对应的图像都有眼镜,z1 – z2会抵消眼镜特征。同理,z2和z3对应的图像均为女性,z3 – z2抵消女性特征。因此,解码z4将生成一个不带眼镜的男性图像。

你还将创建一系列从带眼镜女性到不带眼镜女性的渐变图像,通过调节潜向量z1和z2的权重。这些练习展示了VAE在生成模型领域的灵活性和创造潜力。

相比我们前几章学习的GANs,AE和VAE结构简单,易于构建。训练上,AE和VAE通常比GAN更稳定、更容易。然而,AE和VAE生成的图像往往比GAN生成的更模糊。GAN在生成高质量、逼真图像方面表现卓越,但训练难度大且计算资源消耗高。选择使用GAN还是VAE,取决于任务的具体需求,如输出质量、计算资源和训练稳定性的重要性。

VAE在实际应用中广泛。例如,假设你经营一家眼镜店,成功在线销售一款男士眼镜新品,但想用同款眼镜拓展女性市场,却缺乏女性佩戴该眼镜的照片,且专业摄影成本高昂。此时,VAE便派上用场:你可以将已有的男士戴眼镜照片与男女无眼镜照片结合,通过编码运算生成佩戴同款眼镜的女性逼真图像,如图7.1所示。这正是你将在本章学习的技术。

image.png

在另一种情景中,假设你的店里销售深色和浅色两种款式的眼镜框,这两种款式都很受欢迎。你希望推出一种中间色调的框架款式。借助VAE,通过一种称为编码插值(encoding interpolation)的方法,你可以轻松生成一系列平滑过渡的图像,如图7.2所示。这些图像从深色框逐渐过渡到浅色框,为顾客提供了一个视觉上的多样化选择范围。

image.png

VAEs 的应用不仅限于眼镜领域,实际上几乎涵盖了所有产品类别,无论是服装、家具还是食品。这项技术为可视化和营销各种产品提供了一种富有创意且经济高效的解决方案。此外,尽管图像生成是一个突出例子,VAEs 也能应用于许多其他类型的数据,包括音乐和文本。它们的多功能性为实际应用开辟了无限可能!

7.1 自编码器(AE)概述

本节将介绍什么是自编码器及其基本结构。为了让你深入理解自编码器的工作原理,你将在本章的第一个项目中构建并训练一个自编码器,用于生成手写数字。本节提供了自编码器架构的概览和完成第一个项目的蓝图。

7.1.1 什么是自编码器(AE)?

自编码器是一种用于无监督学习的神经网络,特别适合于图像生成、压缩和去噪等任务。自编码器由两个主要部分组成:编码器和解码器。编码器将输入压缩成低维表示(潜在空间),解码器则根据该表示重建输入数据。

这种压缩后的表示或潜在空间,捕捉了输入数据中最重要的特征。在图像生成中,这个空间编码了网络所训练图像的关键内容。自编码器因其高效学习数据表示和能够处理无标签数据的能力而广泛应用,适合降维和特征学习等任务。自编码器的一个挑战是编码过程中可能丢失信息,导致重建不够准确。采用更深层次、多隐藏层的架构可以帮助学习更复杂、更抽象的表示,从而潜在地缓解信息丢失问题。另外,训练能够生成高质量图像的自编码器通常计算量较大,需要大量数据集。

正如我们在第一章提到的,最好的学习方式是从零开始动手做。为此,你将在本章第一个项目中学习如何创建一个用于生成手写数字的自编码器。下一小节将为你提供操作蓝图。

7.1.2 构建和训练自编码器的步骤

假设你需要从零开始构建并训练一个自编码器,以生成手写数字的灰度图像,从而掌握使用自编码器处理更复杂任务(如彩色图像生成或降维)的技能。你该如何开始这项工作?

图 7.3 展示了自编码器的架构示意图,以及训练自编码器生成手写数字所涉及的步骤。

image.png

图 7.3 自编码器(AE)的架构及其训练步骤,用于生成手写数字。一个自编码器由编码器(图中左中)和解码器(图中右中)组成。在每次训练迭代中,手写数字图像被输入到编码器(步骤1)。编码器将图像压缩为潜在空间中的确定性点(步骤2)。解码器接收来自潜在空间的编码向量(步骤3),并重建图像(步骤4)。自编码器通过调整参数来最小化重建误差,即原始图像与重建图像之间的差异(步骤5)。

从图中可以看出,自编码器有两个主要部分:编码器(左中)将手写数字图像压缩成潜在空间中的向量,解码器(右中)基于编码向量重建图像。编码器和解码器均为深度神经网络,可能包含不同类型的层,如全连接层、卷积层、转置卷积层等。由于本例涉及手写数字的灰度图像,我们仅使用全连接层。但自编码器也可用于生成更高分辨率的彩色图像;对于这类任务,编码器和解码器通常会包含卷积神经网络(CNN)层。是否使用 CNN 取决于所生成图像的分辨率需求。

构建自编码器时,其参数会被随机初始化。我们需要获取训练集以训练模型:PyTorch 提供了 60,000 张均匀分布于数字 0 至 9 的灰度手写数字图像。图 7.3 左侧显示了三个示例,分别是数字 0、1 和 9 的图像。在训练循环的第一步,我们将训练集中的图像输入编码器。编码器将图像压缩为潜在空间中的 20 维向量(步骤2)。数字 20 并无神秘含义,使用 25 维向量也会获得类似结果。随后,我们将向量表示输入解码器(步骤3),要求其重建图像(步骤4)。我们计算原始图像与重建图像间所有像素的均方误差作为重建损失。该损失将被反向传播,用于更新编码器和解码器的参数,以最小化重建损失(步骤5),从而在下一次迭代中,自编码器能够重建出更接近原图的图像。此过程在整个数据集上重复多轮训练。

模型训练完成后,你可以将未见过的手写数字图像输入编码器,获得其编码表示。随后,将编码输入解码器得到重建图像。你会发现重建图像与原始图像几乎相同。图 7.3 右侧展示了三个重建图像示例,与图左侧对应的原始图像极为相似。

7.2 构建和训练自编码器生成数字图像

既然你已经有了构建和训练自编码器(AE)生成手写数字的蓝图,我们现在深入项目,实现上一节中概述的步骤。

具体来说,本节你将学习如何获取训练集和测试集的手写数字图像;然后用全连接层搭建编码器和解码器;使用训练集训练自编码器,并用训练好的编码器对测试集图像进行编码;最后,利用训练好的解码器重建图像,并与原图进行对比。

7.2.1 获取手写数字图像

你可以通过 Torchvision 库中的 datasets 包下载灰度手写数字图像,方式类似于第 2 章下载服装图像的方法。

首先,下载训练集和测试集:

import torchvision
import torchvision.transforms as T

transform = T.Compose([
    T.ToTensor()
])
train_set = torchvision.datasets.MNIST(root=".",  # ①
                                       train=True, download=True, transform=transform)  # ②
test_set = torchvision.datasets.MNIST(root=".",   # ③
                                      train=False, download=True, transform=transform)

① 通过 torchvision.datasets 的 MNIST() 类下载手写数字图像。
② train=True 表示下载训练集。
③ train=False 表示下载测试集。

这里用的是 MNIST() 类,而不是第 2 章用的 FashionMNIST()。train 参数指示 PyTorch 下载训练集(True)还是测试集(False)。转换前图像像素为 0-255 的整数,ToTensor() 类将其转换为 0-1 浮点数的 PyTorch 张量。训练集有 60,000 张图像,测试集有 10,000 张,数字分布均匀(0~9)。

我们将创建数据批次用于训练和测试,每批含 32 张图像:

import torch

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)

数据准备就绪,接下来构建并训练自编码器。

7.2.2 构建与训练自编码器

自编码器由编码器和解码器两部分组成。我们定义一个 AE() 类,如下所示,表示自编码器结构。

import torch.nn.functional as F
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
input_dim = 784     # ① 输入尺寸是 28×28=784  
z_dim = 20          # ② 潜在变量维度是 20  
h_dim = 200

class AE(nn.Module):
    def __init__(self, input_dim, z_dim, h_dim):
        super().__init__()
        self.common = nn.Linear(input_dim, h_dim)
        self.encoded = nn.Linear(h_dim, z_dim)
        self.l1 = nn.Linear(z_dim, h_dim)
        self.decode = nn.Linear(h_dim, input_dim)

    def encoder(self, x):      # ③ 编码器部分,压缩图像至潜在变量  
        common = F.relu(self.common(x))
        mu = self.encoded(common)
        return mu

    def decoder(self, z):      # ④ 解码器部分,根据编码重建图像  
        out = F.relu(self.l1(z))
        out = torch.sigmoid(self.decode(out))
        return out

    def forward(self, x):      # ⑤ 编码器与解码器组成自编码器  
        mu = self.encoder(x)
        out = self.decoder(mu)
        return out, mu

输入尺寸是 784,因为手写数字灰度图是 28×28 像素。图像先被展平为一维张量,输入自编码器。图像先经过编码器,压缩为 20 维潜在变量。解码器再根据潜在变量重建图像。自编码器输出两个张量:out 为重建图像,mu 为潜在变量编码。

然后,我们实例化 AE() 类,创建自编码器模型。训练时使用 Adam 优化器,如前几章所示:

model = AE(input_dim, z_dim, h_dim).to(device)
lr = 0.00025
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

定义一个函数 plot_digits(),用于在每轮训练后视觉检查重建的手写数字,代码如下:

import matplotlib.pyplot as plt

originals = []   # ① 收集测试集中每个数字的样本图像  
idx = 0
for img, label in test_set:
    if label == idx:
        originals.append(img)
        idx += 1
    if idx == 10:
        break

def plot_digits():
    reconstructed = []
    for idx in range(10):
        with torch.no_grad():
            img = originals[idx].reshape((1, input_dim))
            out, mu = model(img.to(device))  # ② 送入自编码器,获得重建图像  
        reconstructed.append(out)  # ③ 收集重建图像
    imgs = originals + reconstructed
    plt.figure(figsize=(10, 2), dpi=50)
    for i in range(20):
        ax = plt.subplot(2, 10, i + 1)
        img = imgs[i].detach().cpu().numpy()
        plt.imshow(img.reshape(28, 28), cmap="binary")  # ④ 视觉比较原图和重建图  
        plt.xticks([])
        plt.yticks([])
    plt.show()

该函数首先收集了 10 张分别代表不同数字的样本图像,放入列表 originals。将图像送入自编码器,获得重建图像。最后将原图和重建图都绘制出来,以便对比和定期评估自编码器性能。

训练开始前调用 plot_digits() 可视化当前输出:

plot_digits()

你将看到如图 7.4 所示的结果。

image.png

图7.4 显示了训练开始前自编码器(AE)重建图像与原始图像的对比。上排是测试集中10张手写数字的原始图像,下排是AE在训练前的重建图像,重建结果不过是纯噪声。

虽然我们可以像第2章那样将数据划分为训练集和验证集,并训练模型直到验证集表现不再提升,但这里的主要目标是理解自编码器的工作原理,而非追求最佳参数调优。因此,我们将训练AE 10个周期。

代码清单7.3 训练AE生成手写数字:

python
复制
for epoch in range(10):
    tloss = 0
    for imgs, labels in train_loader:                   # ①
        imgs = imgs.to(device).view(-1, input_dim)
        out, mu = model(imgs)                            # ②
        loss = ((out - imgs) ** 2).sum()                 # ③
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tloss += loss.item()
    print(f"at epoch {epoch} total loss = {tloss / len(train_loader)}")
    plot_digits()                                        # ④

① 遍历训练集中的所有批次;
② 使用AE重建图像;
③ 计算重建误差,采用均方误差(MSE);
④ 视觉检查AE的表现。

在每个训练周期内,我们遍历训练集的所有数据批次,将原始图像输入AE,得到重建图像。随后计算重建误差,即原图和重建图像逐像素的差值平方平均。通过Adam优化器调整模型参数,最小化该重建误差。

如果使用GPU训练,模型训练时间约2分钟。你也可以从我的网站下载训练好的模型:mng.bz/YV6K

7.2.3 保存和使用训练好的AE

将模型保存在你电脑的本地文件夹中:

scripted = torch.jit.script(model)
scripted.save('files/AEdigits.pt')

要使用模型重建手写数字图像,加载模型:

model = torch.jit.load('files/AEdigits.pt', map_location=device)
model.eval()

然后调用之前定义的 plot_digits() 函数生成图像:

plot_digits()

输出效果如图7.5所示。

image.png

图7.5 显示了训练好的自编码器(AE)重建图像与原始图像的对比。上排是测试集中10张手写数字的原始图像,下排是训练后AE的重建图像。重建图像与原图十分相似。

重建的手写数字确实与原图相似,尽管重建并非完美,编码-解码过程中会丢失部分信息。不过,相较于生成对抗网络(GAN),自编码器构建更简单,训练时间更短。此外,编码器-解码器架构是许多生成模型的基础,本项目将帮助你理解后续章节的内容,特别是在学习Transformer时。

7.3 什么是变分自编码器(VAE)?

虽然自编码器善于重建原始图像,但它们在生成训练集中未出现的新颖图像时表现不佳。此外,自编码器往往不能将相似的输入映射到潜在空间中彼此相近的点,因此其潜在空间既不连续,也不易解释。例如,你无法在两个输入数据点间进行插值生成有意义的中间表示。基于这些原因,我们将学习自编码器的改进版本:变分自编码器(VAE)。

本节你将首先了解AE和VAE的关键区别,以及为什么这些区别使得VAE能够生成训练集中未见过的逼真图像。随后你将学习训练VAE的一般步骤,并重点学习如何训练一个VAE来生成高分辨率的人脸图像。

7.3.1 AE与VAE的区别

变分自编码器(VAE)由Diederik Kingma和Max Welling于2013年首次提出。它是自编码器的一种变体。与AE一样,VAE也包含两个主要部分:编码器和解码器。

不过,AE和VAE有两个关键区别:

  1. 潜在空间的表示方式不同
    AE的潜在空间是确定性的。每个输入都被映射到潜在空间中的一个固定点。
    而VAE的潜在空间是概率性的。VAE将输入编码为潜在空间中一个可能值的分布,而非单个向量。例如,在本章第二个项目中,我们将把一张彩色图像编码为一个100维的概率向量。我们假设该向量中的每个元素都服从一个独立的正态分布。定义正态分布只需均值(μ)和标准差(σ),因此100维概率向量的每个元素由这两个参数描述。为重建图像,我们从这个分布中采样一个向量,再解码它。VAE的独特之处在于,每次采样都会生成略有不同的输出。

    从统计学角度看,VAE中的编码器试图学习训练数据x的真实分布p(x|Θ),其中Θ是定义该分布的参数。为了便于处理,我们通常假设潜在变量的分布为正态分布。由于只需均值μ和标准差σ就可定义正态分布,真实分布可重写为p(x|Θ)=p(x|μ,σ)。VAE的解码器基于编码器学习到的分布生成样本,即解码器从概率分布p(x|μ,σ)中概率性地生成实例。

  2. 损失函数不同
    训练AE时,我们最小化重建损失,使重建图像尽量接近原图。
    而VAE的损失函数由两部分组成:重建损失和KL散度。KL散度衡量一个概率分布相较于第二个期望分布的偏离程度。在VAE中,KL散度用于正则化编码器,惩罚编码器输出的分布偏离先验分布(标准正态分布)的程度。这鼓励编码器学习有意义且具泛化能力的潜在表示。通过惩罚与先验分布偏差过大的分布,KL散度帮助避免过拟合。

由于我们假设潜在空间的分布为正态分布,KL散度在本例中的计算公式如下(如果假设非正态分布,则公式不同):

image.png

求和是对潜在空间的所有100个维度进行的。当编码器将图像压缩成潜在空间中的标准正态分布,使得均值μ=0,标准差σ=1时,KL散度为0。在其他任何情况下,KL散度的值都会大于0。因此,当编码器成功将图像压缩为潜在空间中的标准正态分布时,KL散度被最小化。

7.3.2 训练VAE生成人人脸图像的蓝图

在本章的第二个项目中,你将从零开始构建并训练一个变分自编码器(VAE),用于生成彩色人脸图像。训练好的模型可以生成训练集中未出现过的新图像。此外,你还可以通过对输入进行插值,生成两个输入数据点之间的中间表示的新颖图像。以下是该项目的整体设计蓝图。

图7.6展示了VAE的架构图以及训练VAE以生成人人脸图像的步骤。

image.png

图7.6 展示了变分自编码器(VAE)的架构及训练步骤,用于生成人人脸图像。VAE由编码器(图中间左上方)和解码器(图中间右下方)组成。在每次训练迭代中,将人人脸图像输入编码器(步骤1)。编码器将图像压缩为潜在空间中的概率点(步骤2;由于假设为正态分布,每个概率点由均值向量和标准差向量描述)。接着,我们从该分布中采样编码,并将采样结果输入解码器。解码器接收采样的编码(步骤3)并重建图像(步骤4)。VAE通过调整参数,最小化重建误差和KL散度的总和。KL散度衡量编码器输出与标准正态分布之间的差异。

图7.6还显示,VAE同样分为编码器(左上)和解码器(右下)两部分。由于第二个项目涉及高分辨率彩色图像,我们将使用卷积神经网络(CNN)构建VAE。如第四章所述,高分辨率彩色图像包含的像素远多于低分辨率灰度图像。如果仅使用全连接层,模型参数会过多,导致训练缓慢且效果差。CNN比同等规模的全连接网络参数更少,学习更快、更有效。

VAE创建完成后,将使用第五章下载的眼镜数据集进行训练。图7.6左侧展示了训练集中三张原始人脸图像。训练循环的第一步,将尺寸为3×256×256=196,608像素的训练图像输入编码器。编码器将图像压缩为潜在空间中的100维概率向量(步骤2;由于假设为正态分布,向量包含均值和标准差)。随后从分布中采样,并将采样向量输入解码器(步骤3),解码器重建图像(步骤4)。我们计算总损失,为像素级重建误差与KL散度(公式7.1)之和。将该损失反向传播以更新编码器和解码器参数,最小化总损失(步骤5)。总损失促使VAE编码输入为更具意义和泛化能力的潜在表示,并重建更接近原图的图像。

模型训练完成后,将人脸图像输入编码器获得编码,再将编码输入解码器获得重建图像。你会发现重建图像与原图相似,但非完全一致。图7.6右侧展示了三张重建图像,与左侧对应的原图相似。

更重要的是,可以丢弃编码器,直接从潜在空间随机采样编码输入训练好的解码器,生成训练集中未出现过的新颖人脸图像。你还可以操控不同输入的编码表示,在解码时实现特定效果。同时,可以通过调整两个编码之间的权重,生成一系列图像,实现从一个实例平滑过渡到另一个实例。

7.4 用变分自编码器(VAE)生成人人脸图像

本节将从零开始创建并训练一个VAE,用于生成人人脸图像,步骤遵循上一节的概要。

相比之前构建和训练AE的方法,第二个项目做了若干改进。首先,由于高分辨率彩色图像像素众多,我们计划在VAE的编码器和解码器中都使用卷积神经网络(CNN)。仅依赖全连接(dense)层会导致参数量过大,学习缓慢且效率低下。其次,为了将图像压缩为潜在空间中的正态分布向量,我们在编码每张图像时会同时生成均值向量和标准差向量,这与AE中固定值向量不同。接着从编码的正态分布中采样得到编码,再通过解码器重建图像。每次从该分布采样重建的图像会略有不同,这正体现了VAE生成新颖图像的能力。

7.4.1 构建VAE

如你所知,第五章下载的眼镜数据集经过人工校正标签后保存在电脑的 /files/glasses/ 文件夹内。我们将图像大小调整为256×256像素,像素值范围为0到1。接着创建一个批量迭代器,每批含16张图像:

transform = T.Compose([
            T.Resize(256),       # ① 调整图像大小为256×256
            T.ToTensor(),        # ② 转换图像为值在0到1之间的张量
            ])
data = torchvision.datasets.ImageFolder(
    root="files/glasses",    # ③ 加载指定文件夹内的图像并应用变换
    transform=transform)     
batch_size=16
loader = torch.utils.data.DataLoader(data,   # ④ 创建数据批处理迭代器
     batch_size=batch_size, shuffle=True)

接下来,我们创建一个包含卷积层和转置卷积层的VAE。首先定义编码器Encoder()类:

latent_dims=100  # ① 潜在空间维度为100

class Encoder(nn.Module):
    def __init__(self, latent_dims=100):  
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.batch2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=0)
        self.linear1 = nn.Linear(31*31*32, 1024)
        self.linear2 = nn.Linear(1024, latent_dims)
        self.linear3 = nn.Linear(1024, latent_dims)
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() 
        self.N.scale = self.N.scale.cuda()
    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.conv1(x))
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)            # ② 编码均值
        std = torch.exp(self.linear3(x)) # ③ 编码标准差
        z = mu + std*self.N.sample(mu.shape)  # ④ 采样得到编码
        return mu, std, z

编码器网络由多个卷积层组成,用于提取输入图像的空间特征。编码器将输入压缩成服从正态分布的向量表示z,具有均值mu和标准差std。编码器输出三个张量:mu、std和z,其中z为从该分布采样的实例。

具体地,输入图像尺寸为(3, 256, 256),首先经过一个步幅为2的Conv2d层(如第四章所述,该步幅表示滤波器每次移动跳过2个像素,达到下采样效果),输出尺寸为(8, 128, 128)。之后经过两层Conv2d,尺寸变为(32, 31, 31)。将其展平并经过线性层,得到mu和std的值。

定义解码器Decoder()类:

class Decoder(nn.Module):   
    def __init__(self, latent_dims=100):
        super().__init__()
        self.decoder_lin = nn.Sequential(        # ① 编码先通过两个全连接层
            nn.Linear(latent_dims, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 31*31*32),           # ② 将编码重塑为多维结构,便于进行转置卷积操作
            nn.ReLU(True))
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32,31,31))
        self.decoder_conv = nn.Sequential(       # ③ 三层转置卷积层组成解码器
            nn.ConvTranspose2d(32,16,3,stride=2, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 3, 3, stride=2, padding=1, output_padding=1))
        
    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)          # ④ 将输出限制在0到1之间,与输入图像像素值范围一致
        return x  

解码器结构是编码器的镜像,不同的是它对编码进行转置卷积操作,将潜在空间的编码逐渐转化回高分辨率彩色图像。

具体流程是:编码先通过两个线性层,再被重塑为(32, 31, 31)的形状,形状对应编码器最后一层Conv2d的输出尺寸。然后通过三层ConvTranspose2d层,结构与编码器卷积层对称。解码器输出尺寸为(3, 256, 256),与训练图像一致。

我们将编码器和解码器合并构建VAE模型:

class VAE(nn.Module):
    def __init__(self, latent_dims=100):
        super().__init__()
        self.encoder = Encoder(latent_dims)   # ① 实例化编码器
        self.decoder = Decoder(latent_dims)   # ② 实例化解码器
    def forward(self, x):
        x = x.to(device)
        mu, std, z = self.encoder(x)          # ③ 输入图像编码
        return mu, std, self.decoder(z)       # ④ 输出编码的均值、标准差和重建图像

VAE由编码器和解码器组成,输入图像通过VAE后输出三个张量:编码的均值mu、标准差std,以及重建图像。

接下来实例化VAE并定义优化器:

vae = VAE().to(device)
lr = 1e-4 
optimizer = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-5)

我们将在训练时手动计算重建损失和KL散度损失,因此这里不定义专门的损失函数。

7.4.2 训练VAE

为了训练模型,我们首先定义一个 train_epoch() 函数,用于训练模型一个训练周期(epoch)。
代码示例如下:

def train_epoch(epoch):
    vae.train()
    epoch_loss = 0.0
    for imgs, _ in loader: 
        imgs = imgs.to(device)
        mu, std, out = vae(imgs)                                   # ① 获取重建图像
        reconstruction_loss = ((imgs - out)**2).sum()              # ② 计算重建损失
        kl = ((std**2) / 2 + (mu**2) / 2 - torch.log(std) - 0.5).sum()  # ③ 计算KL散度
        loss = reconstruction_loss + kl                            # ④ 总损失 = 重建损失 + KL散度
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'at epoch {epoch}, loss is {epoch_loss}')

解释:
我们遍历训练集中所有批次,将图像输入VAE得到重建图像。总损失由重建损失与KL散度之和构成。每次迭代都会调整模型参数以最小化该总损失。

我们还定义了一个 plot_epoch() 函数,用于可视化VAE生成的图像:

import numpy as np
import matplotlib.pyplot as plt

def plot_epoch():
    with torch.no_grad():
        noise = torch.randn(18, latent_dims).to(device)        # 从潜在空间随机采样18个向量
        imgs = vae.decoder(noise).cpu()                         # 解码生成图像
        imgs = torchvision.utils.make_grid(imgs, 6, 3).numpy()
        fig, ax = plt.subplots(figsize=(6,3), dpi=100)
        plt.imshow(np.transpose(imgs, (1, 2, 0)))               # 显示图像
        plt.axis("off")
        plt.show()

训练良好的VAE能够将相似输入映射到潜在空间中相近的点,从而使潜在空间更加连续且具有可解释性。这样,我们可以随机从潜在空间采样向量,VAE则能将其解码为有意义的输出。因此,在plot_epoch()函数中,我们每个epoch后随机采样18个潜在向量生成18张图像,并以3×6网格展示,直观检验训练效果。

接下来,训练VAE 10个训练周期:

for epoch in range(1, 11):
    train_epoch(epoch)
    plot_epoch()
torch.save(vae.state_dict(), "files/VAEglasses.pth")

使用GPU训练时,大约需要半小时,否则可能需几小时。训练好的模型权重保存在本地。你也可以从我的网站下载训练好的权重:mng.bz/GNRR ,下载后请务必解压。

7.4.3 使用训练好的VAE生成图像

模型训练完成后,可以用它来生成图像。首先加载本地保存的训练权重:

vae.eval()
vae.load_state_dict(torch.load('files/VAEglasses.pth', map_location=device))

接着检验VAE对图像的重建能力,并比较重建图像与原图的相似度:

imgs, _ = next(iter(loader))
imgs = imgs.to(device)
mu, std, out = vae(imgs)
images = torch.cat([imgs[:8], out[:8], imgs[8:16], out[8:16]], dim=0).detach().cpu()
images = torchvision.utils.make_grid(images, 8, 4)
fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
plt.imshow(np.transpose(images, (1, 2, 0)))
plt.axis("off")
plt.show()

运行以上代码块后,你将看到类似图7.7的输出。

image.png

图7.7 展示了训练好的VAE重建图像与原始图像的对比。第一行和第三行是原始图像,我们将它们输入训练好的VAE,获得的重建图像显示在原图的下方。

原始图像位于第一行和第三行,重建图像位于原始图像下方。重建图像与原始图像相似,如图7.7所示。但在重建过程中,一些信息有所损失,因此它们看起来没有原图那么真实。

接下来,我们测试VAE生成训练集中未见过的新颖图像的能力,调用之前定义的 plot_epoch() 函数:

plot_epoch()

该函数会从潜在空间随机采样18个向量,并将它们输入训练好的VAE以生成18张图像。输出结果如图7.8所示。

image.png

图7.8 训练好的VAE生成的新颖图像。我们从潜在空间随机抽取向量表示,并将它们输入训练好的VAE的解码器。图中显示了这些解码后的图像。由于向量表示是随机抽取的,这些图像并不对应训练集中任何原始图像。

这些图像并非训练集中的内容:编码是从潜在空间随机抽取的,而不是通过编码器对训练集图像编码得到的向量。这是因为VAE的潜在空间是连续且可解释的。潜在空间中新颖且未见过的编码可以被有意义地解码成与训练集相似但不同的图像。

7.4.4 使用训练好的VAE进行编码运算

VAE的损失函数中包含一个正则化项(KL散度),鼓励潜在空间接近正态分布。该正则化项保证潜在变量不仅仅是对训练数据的记忆,而是捕捉底层的分布特征。它有助于构建一个结构良好的潜在空间,使相似的数据点被映射到相近的位置,使空间连续且可解释。因此,我们可以通过操作编码来实现新的生成结果。

为了保证结果可复现,建议你从我的网站(mng.bz/GNRR)下载训练好的权重,并使用本章剩余部分相同的代码块。

正如引言中所述,编码运算(encoding arithmetic)让我们能够生成具有特定特征的图像。为了说明编码运算在VAE中的工作原理,我们首先手动收集以下四组中的三张图片:戴眼镜的男性、未戴眼镜的男性、戴眼镜的女性和未戴眼镜的女性。

torch.manual_seed(0)  
glasses=[]
for i in range(25):                                        ①
    img,label=data[i]
    glasses.append(img)
    plt.subplot(5,5,i+1)
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.axis("off")
plt.show()
men_g=[glasses[0],glasses[3],glasses[14]]                  ②
women_g=[glasses[9],glasses[15],glasses[21]]               ③
  
noglasses=[]
for i in range(25):                                        ④
    img,label=data[-i-1]
    noglasses.append(img)
    plt.subplot(5,5,i+1)
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.axis("off")
plt.show()
men_ng=[noglasses[1],noglasses[7],noglasses[22]]           ⑤
women_ng=[noglasses[4],noglasses[9],noglasses[19]])        ⑥

① 显示25张戴眼镜的图像
② 选取三张戴眼镜的男性图像
③ 选取三张戴眼镜的女性图像
④ 显示25张未戴眼镜的图像
⑤ 选取三张未戴眼镜的男性图像
⑥ 选取三张未戴眼镜的女性图像

我们每组选择三张图像而非一张,是为了后续编码运算时能够计算同一组中多个编码的平均值。VAE设计的目标是学习输入数据在潜在空间中的分布,通过对多个编码取平均,我们有效地平滑了该空间中的表示。这有助于找到能体现组内不同样本共有特征的平均表示。

接下来,我们将三张戴眼镜的男性图像输入训练好的VAE,获得它们在潜在空间中的编码。然后计算这三张图像编码的平均值,并利用它生成一张戴眼镜的男性重建图像。我们对其他三组执行相同操作。

# 创建一批戴眼镜男性图像
men_g_batch = torch.cat((men_g[0].unsqueeze(0),              ①
             men_g[1].unsqueeze(0),
             men_g[2].unsqueeze(0)), dim=0).to(device)
# 获取三张图像的编码
_,_,men_g_encodings=vae.encoder(men_g_batch)
# 对编码取平均,得到该组的编码
men_g_encoding=men_g_encodings.mean(dim=0)                   ②
# 解码平均编码,生成戴眼镜男性图像
men_g_recon=vae.decoder(men_g_encoding.unsqueeze(0))         ③
  
# 对另外三组做同样操作
# 组2,戴眼镜女性
women_g_batch = torch.cat((women_g[0].unsqueeze(0),
             women_g[1].unsqueeze(0),
             women_g[2].unsqueeze(0)), dim=0).to(device)
# 组3,未戴眼镜男性
men_ng_batch = torch.cat((men_ng[0].unsqueeze(0),
             men_ng[1].unsqueeze(0),
             men_ng[2].unsqueeze(0)), dim=0).to(device)
# 组4,未戴眼镜女性
women_ng_batch = torch.cat((women_ng[0].unsqueeze(0),
             women_ng[1].unsqueeze(0),
             women_ng[2].unsqueeze(0)), dim=0).to(device)
# 获取其他三组的平均编码
_,_,women_g_encodings=vae.encoder(women_g_batch)
women_g_encoding=women_g_encodings.mean(dim=0)
_,_,men_ng_encodings=vae.encoder(men_ng_batch)
men_ng_encoding=men_ng_encodings.mean(dim=0)
_,_,women_ng_encodings=vae.encoder(women_ng_batch)
women_ng_encoding=women_ng_encodings.mean(dim=0)              ④
# 解码各组平均编码
women_g_recon=vae.decoder(women_g_encoding.unsqueeze(0))
men_ng_recon=vae.decoder(men_ng_encoding.unsqueeze(0))
women_ng_recon=vae.decoder(women_ng_encoding.unsqueeze(0))    ⑤

① 创建一批戴眼镜男性图像
② 获取戴眼镜男性组的平均编码
③ 解码该平均编码,生成戴眼镜男性图像
④ 获取另外三组的平均编码
⑤ 解码另外三组的平均编码

这四组的平均编码分别是 men_g_encodingwomen_g_encodingmen_ng_encodingwomen_ng_encoding,其中 g 表示戴眼镜(glasses),ng 表示未戴眼镜(no glasses)。四组解码后的图像分别是 men_g_reconwomen_g_reconmen_ng_reconwomen_ng_recon。我们将这四张图像绘制如下:

imgs=torch.cat((men_g_recon,
                women_g_recon,
                men_ng_recon,
                women_ng_recon),dim=0)
imgs=torchvision.utils.make_grid(imgs,4,1).cpu().numpy()
imgs=np.transpose(imgs,(1,2,0))
fig, ax = plt.subplots(figsize=(8,2),dpi=100)
plt.imshow(imgs)
plt.axis("off")
plt.show()

你将看到如图7.9所示的输出结果。

image.png

图7.9 基于平均编码解码的图像。我们首先在以下四个组中各取三张图片:戴眼镜的男性、戴眼镜的女性、未戴眼镜的男性和未戴眼镜的女性。将这12张图片输入训练好的VAE编码器,得到它们在潜在空间中的编码。然后计算每组三张图片的平均编码。将这四个平均编码输入训练好的VAE解码器,得到四张图像,如图中所示。

图7.9中展示了这四张解码后的图像。它们是代表四个组的综合图像。注意,这些图像与原始的12张图片都不同,但同时保留了各组的典型特征。

接下来,我们对编码进行操作,创建一个新的编码,然后用训练好的VAE解码器解码该编码,观察结果。例如,我们可以用“戴眼镜的男性”的平均编码减去“戴眼镜的女性”的平均编码,再加上“未戴眼镜的女性”的平均编码。将结果输入解码器,观察生成的图像。

z = men_g_encoding - women_g_encoding + women_ng_encoding        ①
out = vae.decoder(z.unsqueeze(0))                                ②
imgs = torch.cat((men_g_recon,
                  women_g_recon,
                  women_ng_recon, out), dim=0)
imgs = torchvision.utils.make_grid(imgs, 4, 1).cpu().numpy()
imgs = np.transpose(imgs, (1, 2, 0))
fig, ax = plt.subplots(figsize=(8, 2), dpi=100)
plt.imshow(imgs)                                                 ③
plt.title("man with glasses - woman with glasses + woman without glasses = man without glasses", fontsize=10, c="r")    ④
plt.axis("off")
plt.show()

① 定义z为“戴眼镜男性”编码减“戴眼镜女性”编码再加“未戴眼镜女性”的编码
② 解码z生成图像
③ 显示四张图像
④ 在图像顶部显示标题

如果你运行上述代码块,将会看到如图7.10所示的输出结果。

image.png

图7.10 使用训练好的VAE进行编码算术的示例。我们首先获得三个组的平均编码:戴眼镜的男性(z1)、戴眼镜的女性(z2)和未戴眼镜的女性(z3)。然后定义一个新的编码 z = z1 – z2 + z3。接着将 z 输入训练好的VAE解码器,得到解码后的图像,如图中最右侧所示。

图7.10中的前三张图像是代表三个输入组的综合图像。最右侧的输出图像是一个未戴眼镜的男性形象。

由于 men_g_encoding 和 women_g_encoding 解码后都会出现眼镜特征,men_g_encoding – women_g_encoding 会抵消生成图像中的眼镜特征。同理,women_ng_encoding 和 women_g_encoding 都代表女性面孔,women_ng_encoding – women_g_encoding 会抵消生成图像中的女性特征。因此,如果你用训练好的VAE解码 men_g_encoding + women_g_encoding – women_ng_encoding,就会得到一个未戴眼镜的男性图像。这个编码算术示例说明,通过操作其他三组的平均编码,可以得到未戴眼镜男性的编码。

练习7.1
修改代码清单7.9,执行以下编码算术:

  • 用戴眼镜男性的平均编码减去未戴眼镜男性的平均编码,再加上未戴眼镜女性的平均编码。将结果输入解码器,观察结果。
  • 用未戴眼镜男性的平均编码减去未戴眼镜女性的平均编码,再加上戴眼镜女性的平均编码。将结果输入解码器,观察结果。
  • 用未戴眼镜女性的平均编码减去未戴眼镜男性的平均编码,再加上戴眼镜男性的平均编码。将结果输入解码器,观察结果。
    记得修改图像标题以反映这些变化。解决方案已在本书GitHub仓库中提供:github.com/markhliu/DG…

此外,我们还可以通过给两个编码分配不同权重来进行潜在空间中的插值,生成一个新的编码。然后对该编码进行解码,创建一个合成图像。通过调整权重,可以生成一系列从一个图像平滑过渡到另一个图像的中间图像。

以戴眼镜女性和未戴眼镜女性的编码为例。定义新的编码 z = w * women_ng_encoding + (1 - w) * women_g_encoding,其中 w 是赋予未戴眼镜女性编码的权重。我们从0到1以0.2为步长逐步改变 w,然后对各编码解码,展示生成的六张图像。

代码清单7.10 两个编码插值生成一系列图像示例:

results = []
for w in [0, 0.2, 0.4, 0.6, 0.8, 1.0]:             ①
    z = w * women_ng_encoding + (1 - w) * women_g_encoding   ②
    out = vae.decoder(z.unsqueeze(0))                   ③
    results.append(out)
imgs = torch.cat((results[0], results[1], results[2],
                  results[3], results[4], results[5]), dim=0)
imgs = torchvision.utils.make_grid(imgs, 6, 1).cpu().numpy()
imgs = np.transpose(imgs, (1, 2, 0))
fig, ax = plt.subplots(dpi=100)
plt.imshow(imgs)                                     ④
plt.axis("off")
plt.show()

① 遍历六个不同的权重值 w
② 在两个编码间进行插值
③ 解码插值编码
④ 显示六张生成的图像

运行代码清单7.10后,你将看到如图7.11所示的输出结果。

image.png

图7.11 插值编码生成一系列中间图像。我们首先获得戴眼镜女性(women_g_encoding)和未戴眼镜女性(women_ng_encoding)的平均编码。插值编码 z 定义为 w * women_ng_encoding + (1 - w) * women_g_encoding,其中 w 是赋予未戴眼镜女性编码的权重。我们将 w 从0以0.2为步长递增到1,共生成六个插值编码。然后对这些编码进行解码,并将生成的六张图像显示在图中。

如图7.11所示,从左到右,图像逐渐从戴眼镜女性过渡到未戴眼镜女性。这表明潜在空间中的编码是连续的、有意义的且可以插值的。

练习7.2
修改代码清单7.10,使用以下编码对创建一系列中间图像:
(i) men_ng_encoding 和 men_g_encoding;
(ii) men_ng_encoding 和 women_ng_encoding;
(iii) men_g_encoding 和 women_g_encoding。
解决方案见本书GitHub仓库:github.com/markhliu/DG…

从下一章开始,你将进入自然语言处理领域,学习生成另一种内容形式——文本。不过,到时你将继续使用许多已学工具,如深度神经网络和编码器-解码器架构。

总结

自编码器(AE)由编码器和解码器两个部分组成。编码器将数据压缩为低维的抽象表示(潜在空间),解码器则将编码信息解压重建数据。
变分自编码器(VAE)同样包括编码器和解码器,但与AE有两点关键区别:
首先,AE将每个输入编码为潜在空间的一个确定点,而VAE将其编码为潜在空间中的概率分布。
其次,AE仅关注最小化重建误差,而VAE在学习潜变量的概率分布参数时,最小化包含重建损失和正则化项(KL散度)的损失函数。

VAE中的KL散度确保潜变量分布接近正态分布,促使编码器学习连续、有意义且具泛化能力的潜在表示。

训练良好的VAE能将相似输入映射到潜在空间的相近点,使潜在空间更连续、更具可解释性。因此,VAE可以对潜在空间中的随机向量进行解码,生成训练集中未见过的有意义输出图像。

VAE的潜在空间是连续且可解释的,区别于AE的潜在空间。我们可以操作这些编码实现新的生成效果,也可以通过调整潜在空间中两个编码的权重,生成一系列从一个实例平滑过渡到另一个实例的中间图像。