VAE极简教程

765 阅读6分钟

一、前言

VAE(Variational Auto Encoder)是一种非常强大的自监督学习算法,在 AI 绘画领域发挥重要作用。在 Stable Diffusion 中,VAE 作为其架构上的一大组件。

VAE 可以用来编码解码,同时还具备一定的生成能力。今天我们要做的就是实现一个 VAE 网络,训练并演示其编码解码能力。

二、VAE 原理

在 VAE 之前有一个 AE(Auto Encoder),他们的作用都是用来编码解码。AE 将输入编码成一个向量,而 VAE 则将输入编码成一个分布。将输入编码成分布的的好处在于我们可以从分布中做采样。在介绍 VAE 前,我们简单说说 AE 算法。

2.1 Auto Encoder

Auto Encoder 包含编码器和解码器两个部分,编码器和解码器。编码器将输入编码成一个远小于输入维度的向量,解码器接收编码结果重构输入。具体如图所示:

image.png

在 Auto Encoder 中,训练目标是还原出输入,其 Loss 计算为:

loss=xx^2=xdφ[eθ(x)]2loss = \Vert x-\hat x \Vert_2 = \Vert x-d_φ[e_θ(x)] \Vert_2

其中 e_θ是编码器,d_φ是解码器。可以根据实际情况来修改编码器和解码器的实现。比如常规情况可以使用 MLP 实现编码器和解码器,如果是图像压缩任务则可以使用 CNN 实现编码器和解码器。

这里可以损失可以使用 L2、L1、BCE 等,具体可以根据任务来决定。

2.2 Variational Auto Encoder

VAE 包含编码器和解码器两个部分,编码器将输入编码成两个向量,一个表示分布的均值、一个表示分布的方差,这两个向量可以确定一个分布。因此可以理解为编码器将输入编码成一个分布。解码器则是从编码向量中解码出原内容,这里的编码向量指的是从分布中采样出来的向量。

VAE 的结构如图所示:

image.png

在 VAE 中,训练目标是还原出输入,以及保证编码器输出分布与指定分布相似。其 Loss 计算为:

ux,σx=eθ(x)z=sample(ux,σx)lossreconstruction=xdφ(z)2losssimilarity=DKL(N(ux,σx2)N(0,1))u_x,σ_x = e_θ(x) \\ z = sample(u_x, σ_x) \\ loss_{reconstruction} = \Vert x-d_φ(z) \Vert_2 \\ loss_{similarity} = D_{KL}(N(u_x,σ_x^2)||N(0,1))

首先重构损失和 AE 基本一样,只是将直接从编码器生成到编码器生成分布后再采样。而另外一个是评估编码器生成的分布和标准正太分布的差异,这里使用的是 KL 散度。

三、实现VAE

下面我们来用 PyTorch 实现 VAE。

3.1 编码、解码

我们编写一个 VAE 类,我们需要完成网络结构、编码、解码,代码如下:

import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        # 推理均值
        mu = self.fc_mu(h)
        # 推理方差
        logvar = self.fc_logvar(h)
        return mu, logvar

    def decode(self, z):
        # 从向量中解码
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

在结构上,我们的模型有两个输出,其含义分别为均值和方差。

在编码时,我们通过前向传播得到均值和方差。

在解码时,我们从向量中解码出原内容。

3.2 采样

正常情况下,采样操作为:

zN(μ,σ2) z \sim N(μ, σ²)

但是这个操作是无法计算梯度的,为此我们使用重参数化技巧,将采样操作改为:

εN(0,I)z=μ+σεε \sim N(0, I) \\ z = μ + σ ⊙ ε

在上述公式中,采样结果作为常数参与计算,因此可以计算梯度。现在我们给 VAE 添加 reparameterize 方法:

class VAE(nn.Module):
    ...

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        # 从正太分布采样
        eps = torch.randn_like(std)
        # 重参数化技巧
        return mu + eps * std

3.3 正向传播

最后我们需要完成正向传播的代码,正向传播的操作如下:

  1. 编码器生成分布
  2. 使用重参数化得到 z
  3. 解码器解码

按照上面步骤实现代码如下:

class VAE(nn.Module):
    ...

    def forward(self, x):
        # 1、编码器生成分布
        mu, logvar = self.encode(x.view(-1, 784))
        # 2、重参数化得到 z
        z = self.reparameterize(mu, logvar)
        # 3、解码器解码
        return self.decode(z), mu, logvar

3.4 计算损失

下面我们来实现计算损失的代码,损失包含重构损失和相似度损失,这里只需要代入两个公式即可:

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE损失函数 = 重构损失 + KL散度损失
    """
    # 重构损失
    L1 = F.l2_loss(recon_x, x.view(-1, 784), reduction='sum')
    # KL散度损失
    # KL(N(μ,σ²)||N(0,1)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return L1 + beta * KLD, L1, KLD

四、训练

现在我们来完成训练部分的代码,VAE 的训练和其他网络没有什么区别,这里直接给出代码:

def train_vae(model, device, train_loader, optimizer, epoch, beta=1.0):
    model.train()
    train_loss = 0
    train_l1 = 0
    train_kld = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, logvar = model(data)
        loss, l1, kld = vae_loss(recon_batch, data, mu, logvar, beta)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_l1 += l1.item()
        train_kld += kld.item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Loss: {loss.item() / len(data):.6f}')

    avg_loss = train_loss / len(train_loader.dataset)
    avg_l1 = train_l1 / len(train_loader.dataset)
    avg_kld = train_kld / len(train_loader.dataset)

    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}, '
          f'L1: {avg_l1:.4f}, KLD: {avg_kld:.4f}')

    return avg_loss, avg_l1, avg_kld
    

这里我们使用前面定义的vae_loss作为损失函数,因为 loss 本身是 l1和 kld 的加权和,因此用于反向传播的只有 loss,l1 和 kld 只用于打印结果。

下面我们使用 MNIST 数据集训练一个 VAE 网络,为了简约,这里舍去评估的代码,代码如下:

def main():
    # 设置参数
    batch_size = 128
    epochs = 10
    learning_rate = 1e-3
    latent_dim = 20
    beta = 1.0  # KL散度权重

    # 设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # 模型初始化
    model = VAE(latent_dim=latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f'Model parameters: {sum(p.numel() for p in model.parameters())}')

    # 训练历史记录
    train_losses = []

    # 训练循环
    for epoch in range(1, epochs + 1):
        train_loss, train_l1, train_kld = train_vae(
            model, device, train_loader, optimizer, epoch, beta
        )

        train_losses.append(train_loss)
        
    print('Training completed!')
    # 保存模型
    torch.save(model.state_dict(), 'vae_mnist.pth')
    print('Model saved as vae_mnist.pth')

在训练完成后,我们可以重构图片或生成图片,代码如下:

def generate_samples(model, device, num_samples=64, latent_dim=20):
    """生成样本"""
    model.eval()
    with torch.no_grad():
        sample = torch.randn(num_samples, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        return sample


def reconstruct_images(model, device, test_loader, num_images=8):
    """重构图像"""
    model.eval()
    with torch.no_grad():
        data, _ = next(iter(test_loader))
        data = data[:num_images].to(device)
        recon_batch, _, _ = model(data)

        return data.cpu(), recon_batch.cpu()

生成图片的操作是从正太分布中采样,然后使用解码器解码即可。而重构图片则是执行编码器、重参数化、解码三个操作。

五、总结

VAE 是 AE 的改进模型。VAE 不仅具备压缩能力,同时还具备生成能力。在如今的 Stable Diffusion 中,VAE 作为一个重要模块。由于其自监督学习的特性,VAE 有广泛应用,比如 Musetalk 中 VAE 也作为一个重要模块存在。

另外 VAE 还存在一些变种,比如 Quantized VAE 就是一大代表。感兴趣的读者可以自行了解。