一、前言
VAE(Variational Auto Encoder)是一种非常强大的自监督学习算法,在 AI 绘画领域发挥重要作用。在 Stable Diffusion 中,VAE 作为其架构上的一大组件。
VAE 可以用来编码解码,同时还具备一定的生成能力。今天我们要做的就是实现一个 VAE 网络,训练并演示其编码解码能力。
二、VAE 原理
在 VAE 之前有一个 AE(Auto Encoder),他们的作用都是用来编码解码。AE 将输入编码成一个向量,而 VAE 则将输入编码成一个分布。将输入编码成分布的的好处在于我们可以从分布中做采样。在介绍 VAE 前,我们简单说说 AE 算法。
2.1 Auto Encoder
Auto Encoder 包含编码器和解码器两个部分,编码器和解码器。编码器将输入编码成一个远小于输入维度的向量,解码器接收编码结果重构输入。具体如图所示:
在 Auto Encoder 中,训练目标是还原出输入,其 Loss 计算为:
其中 e_θ是编码器,d_φ是解码器。可以根据实际情况来修改编码器和解码器的实现。比如常规情况可以使用 MLP 实现编码器和解码器,如果是图像压缩任务则可以使用 CNN 实现编码器和解码器。
这里可以损失可以使用 L2、L1、BCE 等,具体可以根据任务来决定。
2.2 Variational Auto Encoder
VAE 包含编码器和解码器两个部分,编码器将输入编码成两个向量,一个表示分布的均值、一个表示分布的方差,这两个向量可以确定一个分布。因此可以理解为编码器将输入编码成一个分布。解码器则是从编码向量中解码出原内容,这里的编码向量指的是从分布中采样出来的向量。
VAE 的结构如图所示:
在 VAE 中,训练目标是还原出输入,以及保证编码器输出分布与指定分布相似。其 Loss 计算为:
首先重构损失和 AE 基本一样,只是将直接从编码器生成到编码器生成分布后再采样。而另外一个是评估编码器生成的分布和标准正太分布的差异,这里使用的是 KL 散度。
三、实现VAE
下面我们来用 PyTorch 实现 VAE。
3.1 编码、解码
我们编写一个 VAE 类,我们需要完成网络结构、编码、解码,代码如下:
import torch
import torch.nn as nn
class VAE(nn.Module):
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
super(VAE, self).__init__()
# Encoder
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, input_dim)
def encode(self, x):
h = F.relu(self.fc1(x))
# 推理均值
mu = self.fc_mu(h)
# 推理方差
logvar = self.fc_logvar(h)
return mu, logvar
def decode(self, z):
# 从向量中解码
h = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h))
在结构上,我们的模型有两个输出,其含义分别为均值和方差。
在编码时,我们通过前向传播得到均值和方差。
在解码时,我们从向量中解码出原内容。
3.2 采样
正常情况下,采样操作为:
但是这个操作是无法计算梯度的,为此我们使用重参数化技巧,将采样操作改为:
在上述公式中,采样结果作为常数参与计算,因此可以计算梯度。现在我们给 VAE 添加 reparameterize 方法:
class VAE(nn.Module):
...
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
# 从正太分布采样
eps = torch.randn_like(std)
# 重参数化技巧
return mu + eps * std
3.3 正向传播
最后我们需要完成正向传播的代码,正向传播的操作如下:
- 编码器生成分布
- 使用重参数化得到 z
- 解码器解码
按照上面步骤实现代码如下:
class VAE(nn.Module):
...
def forward(self, x):
# 1、编码器生成分布
mu, logvar = self.encode(x.view(-1, 784))
# 2、重参数化得到 z
z = self.reparameterize(mu, logvar)
# 3、解码器解码
return self.decode(z), mu, logvar
3.4 计算损失
下面我们来实现计算损失的代码,损失包含重构损失和相似度损失,这里只需要代入两个公式即可:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
"""
VAE损失函数 = 重构损失 + KL散度损失
"""
# 重构损失
L1 = F.l2_loss(recon_x, x.view(-1, 784), reduction='sum')
# KL散度损失
# KL(N(μ,σ²)||N(0,1)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return L1 + beta * KLD, L1, KLD
四、训练
现在我们来完成训练部分的代码,VAE 的训练和其他网络没有什么区别,这里直接给出代码:
def train_vae(model, device, train_loader, optimizer, epoch, beta=1.0):
model.train()
train_loss = 0
train_l1 = 0
train_kld = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss, l1, kld = vae_loss(recon_batch, data, mu, logvar, beta)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_l1 += l1.item()
train_kld += kld.item()
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
f'Loss: {loss.item() / len(data):.6f}')
avg_loss = train_loss / len(train_loader.dataset)
avg_l1 = train_l1 / len(train_loader.dataset)
avg_kld = train_kld / len(train_loader.dataset)
print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}, '
f'L1: {avg_l1:.4f}, KLD: {avg_kld:.4f}')
return avg_loss, avg_l1, avg_kld
这里我们使用前面定义的vae_loss作为损失函数,因为 loss 本身是 l1和 kld 的加权和,因此用于反向传播的只有 loss,l1 和 kld 只用于打印结果。
下面我们使用 MNIST 数据集训练一个 VAE 网络,为了简约,这里舍去评估的代码,代码如下:
def main():
# 设置参数
batch_size = 128
epochs = 10
learning_rate = 1e-3
latent_dim = 20
beta = 1.0 # KL散度权重
# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 模型初始化
model = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print(f'Model parameters: {sum(p.numel() for p in model.parameters())}')
# 训练历史记录
train_losses = []
# 训练循环
for epoch in range(1, epochs + 1):
train_loss, train_l1, train_kld = train_vae(
model, device, train_loader, optimizer, epoch, beta
)
train_losses.append(train_loss)
print('Training completed!')
# 保存模型
torch.save(model.state_dict(), 'vae_mnist.pth')
print('Model saved as vae_mnist.pth')
在训练完成后,我们可以重构图片或生成图片,代码如下:
def generate_samples(model, device, num_samples=64, latent_dim=20):
"""生成样本"""
model.eval()
with torch.no_grad():
sample = torch.randn(num_samples, latent_dim).to(device)
sample = model.decode(sample).cpu()
return sample
def reconstruct_images(model, device, test_loader, num_images=8):
"""重构图像"""
model.eval()
with torch.no_grad():
data, _ = next(iter(test_loader))
data = data[:num_images].to(device)
recon_batch, _, _ = model(data)
return data.cpu(), recon_batch.cpu()
生成图片的操作是从正太分布中采样,然后使用解码器解码即可。而重构图片则是执行编码器、重参数化、解码三个操作。
五、总结
VAE 是 AE 的改进模型。VAE 不仅具备压缩能力,同时还具备生成能力。在如今的 Stable Diffusion 中,VAE 作为一个重要模块。由于其自监督学习的特性,VAE 有广泛应用,比如 Musetalk 中 VAE 也作为一个重要模块存在。
另外 VAE 还存在一些变种,比如 Quantized VAE 就是一大代表。感兴趣的读者可以自行了解。