【深度学习Day12】变分自编码器(VAE):给AE装个“概率导航仪”,解锁AI造图技能!

263 阅读10分钟

自编码器(AE)生成图像时常常‘歇菜’,但变分自编码器(VAE)却能将‘城中村’般的Latent空间规整成‘商品房小区’,从而生成高质量的新图像。

摘要:上一期咱们把自编码器(AE)玩明白了,但这货有个致命bug——Latent空间跟“城中村”似的,东一块西一块,随便插值出来的图全是乱码!今天咱给AE升级成变分自编码器(VAE),核心就是给它装个“概率导航仪”,让Latent空间变成“规整的商品房小区”,不仅能降维,还能凭空造图——比如生成不存在的T恤、运动鞋、连衣裙!全程PyTorch实战,用FashionMNIST(灰度服饰数据集)造图,幽默拆解VAE的概率魔法,从原理到落地,带你吃透生成式模型的入门神器~

关键词:变分自编码器VAE、概率建模、图像生成、PyTorch、Latent空间、无监督学习、生成式模型

1. 开篇灵魂拷问:为啥AE造不了图?——城中村vs商品房的神比喻

用AE做降维、去噪贼顺手,但想让它“造图”就歇菜了——比如你拿“猫”的Latent向量z1z_1和“狗”的Latent向量z2z_2,取中间值z3z_3,用AE解码器还原,出来的大概率是“像素乱码”,啥也不是。

为啥?咱用接地气的比喻说透:

  • AE的Latent空间=城中村:房子(数据特征)东一栋西一栋,毫无规律,两个房子中间是荒地,走过去全是坑;
  • VAE的Latent空间=商品房小区:房子按固定规则排布(服从高斯分布),任意两栋房子之间都有平整的路,插值走过去能看到“渐变的房子”(比如从猫渐变到狗)。

VAE的核心牛逼之处:给Latent空间加了“概率规矩” ,让它从“乱糟糟的城中村”变成“规整的商品房小区”,这也是VAE能做“生成任务”的根本原因——AI终于能“创造数据”了!

2. VAE核心原理:给AE装个“概率导航仪”

简单说:VAE = AE + 高斯分布约束——编码器输出的不是固定的Latent向量,而是Latent向量的“均值”和“方差”,通过采样得到zz,迫使所有Latent向量符合正态分布。这也是VAE能做“图像生成”的核心原因。

VAE本质核心就3个新增玩意儿:编码器输出均值+方差、重参数化技巧、KL散度损失。咱用“抽奖”的比喻,从0拆解,新手也能懂。

2.1 先回顾AE的痛点

AE的编码器输出固定的Latent向量z——比如输入猫的图像,就输出一个固定的128维z。问题是:不同猫的zz可能离得十万八千里,中间没有过渡,自然插值不出正常图。

2.2 VAE的核心改进:编码器不输出z,输出“抽奖规则”

VAE的编码器不直接输出z,而是输出两个向量:

  • 均值μ\mu(mu) :抽奖的“中奖号码均值”;
  • 方差σ2\sigma²(sigma²) :抽奖的“号码波动范围”。

然后按这个规则“抽一个奖”,得到真正的Latent向量zz——相当于每输入一张图,不是给一个固定zz,而是给一个“zz的取值范围”,再随机选一个zz出来。

2.3 关键中的关键:重参数化技巧(新手必懂)

如果直接从N(μ,σ2)N(μ, σ²)里抽样得到 zz,反向传播时梯度会断(因为抽样是随机操作,不可导)。VAE的“重参数化”就是给随机操作“开后门”,让梯度能传回去:

重参数化公式z=μ+σϵz = μ + σ\odot\epsilon(ε是从标准正态分布N(0,1)N(0,1)里抽的固定噪声),这里的\odot是逐元素乘法,也就是PyTorch里的*,MATLAB里的.*

通俗理解:把“随机抽样”拆成“固定均值 + 固定噪声×波动范围”,既保留随机性,又能让梯度正常传播——相当于把“盲抽”变成“有规则的抽”,还不破坏训练逻辑。

2.4 VAE的双损失:既要“还原得像”,又要“守规矩”

VAE的训练损失是“重构损失 + KL散度损失”,相当于给模型定了两个规矩:

  1. 重构损失(MSE) :和AE一样,要求还原图和原图长得像(保证生成图有意义);
  2. KL散度损失:要求编码器输出的(μ,σ2)(μ, σ²)尽可能接近标准正态分布N(0,1)N(0,1)(保证Latent空间规整,能插值)。

💡 老鸟踩坑提醒:KL散度的权重是“玄学”!调大了Latent空间贼规整,但生成图模糊(模型只顾守规矩,忘了还原);调小了图清晰,但空间又乱了(回到AE的老问题)。

2.5 VAE的训练逻辑(一步到位)

  1. 输入图像XX → 编码器输出μμσσ
  2. 重参数化抽样得到z=μ+σεz = μ + σ\odotε
  3. 解码器用zz还原出X^
  4. 计算双损失:MSE(X,X^)+KL(N(μ,σ)N(0,1))MSE(X, X̂) + KL(N(μ, σ)|| N(0, 1))
  5. 反向传播优化参数,让两个损失都最小。

3. 代码实战:PyTorch实现VAE,生成FashionMNIST图像

咱用FashionMNIST(灰度服饰数据集,28×28单通道)做实战,实现3个核心功能:训练VAE、生成随机服饰图像、Latent空间插值生成渐变服饰图。代码可直接复制运行。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR

# ========== 1. 环境配置+Fashion-MNIST数据准备 ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")
batch_size = 128
epochs = 30  # 适配Fashion-MNIST的训练轮次
lr = 4e-4    # 微调学习率
latent_dim = 64  # 64维latent适配时尚单品特征

# 预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # (1,28,28),像素0~1
])

# 加载Fashion-MNIST
trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

# Fashion-MNIST类别名称
fashion_classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# ========== 2. 传统VAE模型 ==========
class VAE(nn.Module):  
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 28→14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14→7
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),# 7→7
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten(),  # 128×7×7=6272维
        )
        self.fc_mu = nn.Linear(128 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(128 * 7 * 7, latent_dim)

        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),# 7→7
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 7→14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),  # 14→28
            nn.Sigmoid()  # 约束0~1
        )

    # 传统VAE重参数化
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var

# ========== 3. 传统VAE损失函数 ==========
def vae_loss(x_hat, x, mu, log_var):
    # 重构损失:MSELoss
    recon_loss = nn.MSELoss(reduction='sum')(x_hat, x) / x.size(0)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
    # 总损失:重构损失 + KL损失
    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss

# ========== 4. 训练传统VAE ==========
model = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=1e-5)  # 余弦退火

def train_vae(model, trainloader, optimizer, scheduler, epochs, device):
    model.train()
    loss_history = []
    recon_loss_history = []
    kl_loss_history = []

    for epoch in range(epochs):
        running_loss = 0.0
        running_recon_loss = 0.0
        running_kl_loss = 0.0

        for i, (inputs, _) in enumerate(trainloader):
            inputs = inputs.to(device)
            outputs, mu, log_var = model(inputs)
            total_loss, recon_loss, kl_loss = vae_loss(outputs, inputs, mu, log_var)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            running_recon_loss += recon_loss.item()
            running_kl_loss += kl_loss.item()

        scheduler.step()
    return loss_history, recon_loss_history, kl_loss_history

# 开始训练
print("\n===== 开始训练Fashion-MNIST 传统VAE =====")
loss_history, recon_loss_history, kl_loss_history = train_vae(
    model, trainloader, optimizer, scheduler, epochs, device
)

# ========== 6. 随机生成时尚单品 ==========
model.eval()
with torch.no_grad():
    random_z = torch.randn(12, latent_dim).to(device)
    generated_imgs = model.decoder(random_z)
    generated_imgs_np = generated_imgs.cpu().permute(0, 2, 3, 1).squeeze(-1).numpy()
    generated_imgs_np = np.clip(generated_imgs_np, 0.0, 1.0)

# ========== 7. 同类单品插值 ==========
with torch.no_grad():
    # 筛选运动鞋(标签7)做插值
    while True:
        inputs, labels = next(iter(trainloader))
        sneaker_idx = (labels == 7).nonzero().squeeze()
        if len(sneaker_idx) >= 2:
            img1 = inputs[sneaker_idx[0]:sneaker_idx[0]+1].to(device)
            img2 = inputs[sneaker_idx[1]:sneaker_idx[1]+1].to(device)
            break

    # 提取mu插值
    _, mu1, _ = model(img1)
    _, mu2, _ = model(img2)

    steps = 10
    interpolated_z = [(1 - t/(steps-1))*mu1 + t/(steps-1)*mu2 for t in range(steps)]
    interpolated_z = torch.cat(interpolated_z, dim=0)
    interpolated_imgs = model.decoder(interpolated_z)
    interpolated_imgs_np = interpolated_imgs.cpu().permute(0,2,3,1).squeeze(-1).numpy()
    interpolated_imgs_np = np.clip(interpolated_imgs_np, 0.0, 1.0)

可以得到下面随机生成的图像,成功生成出了不同服饰和鞋子的图像。 Copy of vae_fashion_generated.png

通过latent空间插值可以看到两个鞋的图片样本过渡的过程:

Copy of vae_fashion_interpolation.png

4. VAE vs AE:核心区别

对比维度AE(自编码器)VAE(变分自编码器)
Latent空间城中村(离散、无规律)商品房小区(连续、服从高斯分布)
核心能力降维、去噪、压缩(判别式)降维+生成、插值(生成式)
损失函数仅重构损失(MSE)重构损失 + KL散度损失(双损失)
插值效果乱码(中间无有效特征)渐变图(平滑过渡)
训练难度简单(调参少)稍难(KL权重是玄学)
应用场景无标签数据降维、去噪图像生成、风格迁移、数据增广

💡 老鸟经验:如果只是做FashionMNIST这类简单灰度图的降维/去噪,用AE就够了(省算力);如果要生成新服饰、做风格渐变,必须用VAE——别杀鸡用牛刀,也别用菜刀砍坦克~

5. 面试避坑指南(高频问题)

Q1:AE和VAE的核心区别是什么?各自适用场景?

答:核心区别是“Latent空间的性质”:

  • AE的Latent空间离散不连续,适合降维、去噪、压缩等“判别式任务”;
  • VAE的Latent空间连续平滑,适合图像生成、风格迁移等“生成式任务”

Q2:VAE的重参数化技巧是干啥的?为啥必须要?

答:给随机抽样“开后门”!直接从N(μ,σ2)N(μ, σ²)抽样会让梯度断档(随机采样操作不可导),重参数化把抽样拆成“固定均值 + 固定噪声×波动”,既保留随机性,又能让梯度正常传——相当于给随机操作装了“梯度电梯”,不然模型训不了。

Q3:KL散度损失的作用是啥?调大/调小会咋样?

答:KL散度是“规矩监督员”,逼着Latent空间服从标准正态分布。调大了:空间贼规整,但生成图模糊(模型只顾守规矩,忘了还原);调小了:图清晰,但空间又乱了(回到AE的老问题)——就像管孩子,管太严没创造力,管太松没规矩。

Q4:VAE为啥能生成新图像?

答:因为VAE的Latent空间是连续的高斯分布,随便从这个分布里抽一个zz,解码出来都是“符合数据规律的新图像”——相当于小区里随便选个地址,盖出来的房子都符合小区风格,不会是空中楼阁。

📌 下期预告

咱已经搞定了VAE(这可是未来生成模型的重要铺垫!),下一篇直接上“序列数据的敲门砖”——torch.nn里的循环层(RNN基础)+ 嵌入层!之前咱练的CNN专克图像类任务,像个“视觉专家”;但遇到文字、时间序列这类数据就歇菜了,而循环层+嵌入层,就是让AI变身“语言/序列小能手”的核心装备,专门为之后的RNN/LSTM实战打基础!

至于VAE埋下的“生成模型”伏笔,放心!咱后续会开 “生成式模型”专属专栏,从VAE进阶到GAN、Diffusion,把“AI造图、AI写文”的核心逻辑拆得明明白白,再也不用怕“炼丹炉崩了”~ 而现在,咱先稳扎稳打,把循环层、嵌入层这些基础练扎实,后续学高阶内容才会像开了挂!