从零开始打造AI画图大师:条件扩散模型完整实现与无分类器指引详解

0 阅读18分钟

你有没有想过,AI是如何听懂你的指令,画出你想要的东西的?当你对Midjourney输入“一只穿着宇航服的柴犬”,它真的能生成那张图——这背后究竟发生了什么?

今天,我将带你亲手实现一个基础的文本控制AI绘图系统。虽然我们做的是“数字0~9”的控制,但原理,和那些动辄几十亿参数的大模型,完全一致。


一、更上一层楼:让AI听懂你的“命令”

在之前的项目中,我们的扩散模型虽然能生成MNIST手写数字,但它是完全不受控的——你无法告诉它“我要一个数字5”,它生成什么全凭运气。

1.1 核心思维:从无条件到有条件

想象一下你请一个画家画画:

  • 无条件扩散模型

    :你告诉画家“随便画点啥”。他画什么你都只能接受,完全看他的心情。

  • 条件扩散模型

    :你告诉画家“给我画一个数字8”。他听懂了你的指令,专门为你创作一个8。

这就是条件扩散模型的核心思想——我们在神经网络中引入了一个额外的输入,也就是条件y,告诉模型“我想要什么”。

条件y可以是各种各样的东西:一个数字标签(就像我们今天要做的MNIST手写数字)、一段文本描述、一张低分辨率的图像(这就是超分辨率技术),甚至是边缘检测图或姿态关键点。

1.2 简单的实现思路

如何让模型消化这个“条件”呢?关键是把y变成它能理解的数学形式。

  • 神经网络不认识“5”这个整数,就像你不认识外星文一样。

  • 我们需要一个翻译官,把“5”翻译成神经网络能理解的向量

  • 这个翻译官,在深度学习里叫做嵌入层(Embedding Layer) 。

具体的实现思路是这样的:

  • 我们首先仍然使用正弦位置编码,把时间步信息t(比如当前是第100步去噪)变成模型能理解的向量。

  • 然后,我们也用一个嵌入层,把输入的指令y(比如数字“5”)也变成一个特征向量。

  • 最后,简单粗暴但极其有效:把这两个向量相加!此时,模型接收到的信息就同时包含了“时间信息”和“用户指令”。模型自然就知道,在这个时间点,它应该朝着“生成5”的方向去努力了。

这个方法的优雅之处在于,它对原始模型结构的改动极其微小,但效果却立竿见影。


二、代码实战(一):打造能听懂指令的AI画师

纸上得来终觉浅,绝知此事要躬行。下面我们动手实现上面讲的条件扩散模型。
(这里有一个细节需要注意:在本节及后续的所有代码实现中,DDPM的前向加噪过程使用了“累积极大值” \bar{\alpha}t;而在反向去噪计算 \mu\theta 时则使用“原始值” \alpha_t 与上一节的数学推导完全一致。denoise() 函数也必须同时接收这两个参数。

import math
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
 
# ========== 超参数设置 ==========
img_size = 28          # MNIST图像尺寸 28x28
batch_size = 128       # 批次大小
num_timesteps = 1000   # 扩散步数(DDPM的标准配置)
epochs = 10            # 训练轮数
lr = 1e-3              # 学习率
device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
# ========== 辅助函数 ==========
def show_images(images, labels=None, rows=2, cols=10):
    """展示生成的图像(带标签)"""
    fig = plt.figure(figsize=(cols, rows))
    i = 0
    for r in range(rows):
        for c in range(cols):
            ax = fig.add_subplot(rows, cols, i + 1)
            plt.imshow(images[i], cmap='gray')
            if labels is not None:
                ax.set_xlabel(labels[i].item())
            ax.get_axes().set_ticklabels([])
            ax.get_axes().set_ticks([])
            i += 1
    plt.tight_layout()
    plt.show()
 
def _pos_encoding(time_idx, output_dim, device='cpu'):
    """为单个时间步生成正弦位置编码"""
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)
    i = torch.arange(0, D, device=device)
    # 关键的计算公式:div_term = 10000^(2i/D)
    div_term = torch.exp(i / D * math.log(10000))
    # 偶数位用正弦,奇数位用余弦
    v[0::2] = torch.sin(t / div_term[0::2])
    v[1::2] = torch.cos(t / div_term[1::2])
    return v
 
def pos_encoding(timesteps, output_dim, device='cpu'):
    """为批次中的所有时间步生成正弦位置编码"""
    batch_size = len(timesteps)
    device = timesteps.device
    v = torch.zeros(batch_size, output_dim, device=device)
    for i in range(batch_size):
        v[i] = _pos_encoding(timesteps[i], output_dim, device=device)
    return v
 
# ========== 卷积块(带时间嵌入) ==========
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        # 双卷积层:Conv -> BN -> ReLU -> Conv -> BN -> ReLU
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        # MLP将时间嵌入映射到合适的特征维度
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch)
        )
 
    def forward(self, x, v):
        N, C, _, _ = x.shape
        # 时间嵌入经过MLP后 reshape 为 (N, C, 1, 1)
        v = self.mlp(v)
        v = v.view(N, C, 1, 1)
        # 将时间嵌入加到输入上(特征调制)
        y = self.convs(x + v)
        return y
 
# ========== 条件U-Net模型 ==========
class UNetCond(nn.Module):
    def __init__(self, in_ch=1, time_embed_dim=100, num_labels=None):
        super().__init__()
        self.time_embed_dim = time_embed_dim
 
        # U-Net的编码器(下采样路径)
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)   # 28 -> 28(保留尺寸经池化->14)
        self.down2 = ConvBlock(64, 128, time_embed_dim)     # 14 -> 14(保留尺寸经池化->7)
        # 瓶颈层(最低分辨率)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)     # 7 -> 7
        # 解码器(上采样路径)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim) # 7 -> 14(拼接来自down2的特征)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)   # 14 -> 28(拼接来自down1的特征)
        # 输出层
        self.out = nn.Conv2d(64, in_ch, 1)                  # 1x1卷积输出噪声预测
 
        self.maxpool = nn.MaxPool2d(2)           # 2倍下采样
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')  # 2倍上采样
 
        # ========== 关键!:处理标签的嵌入层 ==========
        if num_labels is not None:
            # 将整数标签(0-9)转换为 time_embed_dim 维的向量
            self.label_emb = nn.Embedding(num_labels, time_embed_dim)
 
    def forward(self, x, timesteps, labels=None):
        # 1. 将时间步转换为正弦位置编码
        t = pos_encoding(timesteps, self.time_embed_dim)
 
        # 2. 如果有标签,将标签转换为嵌入并加到时间编码上
        if labels is not None:
            # label_emb(labels) 的形状是 (batch_size, time_embed_dim)
            # 直接加到时间编码上,两个信号融合
            t += self.label_emb(labels)
 
        # 3. U-Net 前向传播
        # 编码器路径
        x1 = self.down1(x, t)      # 保存用于跳跃连接
        x = self.maxpool(x1)       # 下采样
        x2 = self.down2(x, t)      # 保存用于跳跃连接
        x = self.maxpool(x2)       # 下采样
 
        # 瓶颈层
        x = self.bot1(x, t)
 
        # 解码器路径(带跳跃连接)
        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)   # 拼接(跳跃连接)
        x = self.up2(x, t)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)   # 拼接(跳跃连接)
        x = self.up1(x, t)
 
        # 输出噪声预测
        x = self.out(x)
        return x

去噪扩散封装器(Diffuser)

封装正向加噪与反向去噪流程。

class Diffuser:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        self.num_timesteps = num_timesteps
        self.device = device
        # 线性噪声调度(beta从0.0001线性增加到0.02,DDPM论文的原始配置)
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1 - self.betas                     # alpha_t = 1 - beta_t
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # \bar{alpha}_t = Π alpha_{1..t}
 
    def add_noise(self, x_0, t):
        """前向扩散:向干净图像添加噪声,得到 x_t"""
        # t 从 1 到 T,索引需要 -1 才能对齐 alpha_bars[0] 对应 t=1
        t_idx = t - 1
        alpha_bar = self.alpha_bars[t_idx]
        # reshape 为 (N, 1, 1, 1) 用于广播
        alpha_bar = alpha_bar.view(alpha_bar.size(0), 1, 1, 1)
 
        # 生成高斯噪声,并与干净图像按公式混合
        noise = torch.randn_like(x_0, device=self.device)
        # x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise
 
    def denoise(self, model, x, t, labels):
        """反向扩散:从 x_t 去噪得到 x_{t-1}"""
        t_idx = t - 1
        # 获取当前时间步的关键参数
        alpha = self.alphas[t_idx]                    # alpha_t
        alpha_bar = self.alpha_bars[t_idx]            # \bar{alpha}_t
        alpha_bar_prev = self.alpha_bars[t_idx-1]     # \bar{alpha}_{t-1}(t=1时自动处理)
 
        N = alpha.size(0)
        alpha = alpha.view(N, 1, 1, 1)
        alpha_bar = alpha_bar.view(N, 1, 1, 1)
        alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)
 
        # 使用模型预测噪声
        model.eval()
        with torch.no_grad():
            eps = model(x, t, labels)     # 【关键】:同时传入 labels!
        model.train()
 
        # 计算去噪均值 mu
        # mu = (x - ( (1 - alpha) / sqrt(1 - alpha_bar) ) * eps) / sqrt(alpha)
        mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
        std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
 
        # 添加噪声(DDPM采样时的随机性)
        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0   # t=1时是最后一步,不添加噪声
        return mu + noise * std
 
    def reverse_to_img(self, x):
        """将张量数据转换为可显示的 PIL 图像"""
        x = x * 255
        x = x.clamp(0, 255)
        x = x.to(torch.uint8)
        x = x.cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(x)
 
    def sample(self, model, x_shape=(20, 1, 28, 28), labels=None):
        """从随机噪声开始,逐步去噪生成图像"""
        batch_size = x_shape[0]
        x = torch.randn(x_shape, device=self.device)   # 纯随机噪声开始
        if labels is None:
            # 如果没给标签,就随机生成0~9的标签
            labels = torch.randint(0, 10, (batch_size,), device=self.device)
 
        # 从 T 步逐步去噪到 1 步
        for i in tqdm(range(self.num_timesteps, 0, -1)):
            t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t, labels)
 
        # 转换格式并返回
        images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
        return images, labels
 
 
# ========== 数据加载 ==========
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='../datasets', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
# ========== 初始化模型和优化器 ==========
diffuser = Diffuser(num_timesteps, device=device)
model = UNetCond(num_labels=10)   # 10个类别(数字0-9)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
 
# ========== 训练循环 ==========
losses = []
for epoch in range(epochs):
    loss_sum = 0.0
    cnt = 0
 
    # 每个 epoch 结束后生成一组图像,观察训练进展
    images, labels = diffuser.sample(model)
    show_images(images, labels)
 
    for images, labels in tqdm(dataloader):
        optimizer.zero_grad()
        x = images.to(device)
        labels = labels.to(device)      # 【关键】:训练时也要提供标签!
        t = torch.randint(1, num_timesteps+1, (len(x),), device=device)
 
        # 添加噪声并预测
        x_noisy, noise = diffuser.add_noise(x, t)
        noise_pred = model(x_noisy, t, labels)   # 【关键】:模型同时接收t和labels
        loss = F.mse_loss(noise, noise_pred)
 
        loss.backward()
        optimizer.step()
 
        loss_sum += loss.item()
        cnt += 1
 
    loss_avg = loss_sum / cnt
    losses.append(loss_avg)
    print(f'Epoch {epoch} | Loss: {loss_avg:.4f}')
 
# 最终生成展示
images, labels = diffuser.sample(model)
show_images(images, labels)

运行结果: 经过短短10轮训练,模型已经学会了根据标签生成对应的数字。虽然边缘还有些模糊,但它确实成功理解了你的“指令”。


三、AI的进阶之路:从得分函数到分类器指引

上一步的模型虽然能工作了,但它有时候会“偷懒”,不那么看重你给的条件,甚至可能会忽略。为了解决这个问题,我们需要引入一种更强大的技术,它的名字听起来很学术,但原理非常直观,这就是——指引(Guidance) 。

3.1 得分函数——AI内部的“导航仪”

在讨论指引之前,我们需要了解一下扩散模型内部是怎么工作的。扩散模型内部有一个重要的概念叫做得分函数,它是模型判断“这像不像一张真实图像”的内部标尺。数学上定义为对数概率密度相对于输入数据向量的梯度。

一句话理解得分函数:
想象你在一个黑暗的山谷里探索,你蒙着眼睛,目标是走到谷底。得分函数就像你脚下感知坡度的触觉——它会告诉你哪个方向是“下坡”,哪里是“上坡” 。模型就是循着这个“下坡方向”,一步步把噪声“修”成干净图像(数据点会自然聚集在概率高密度的谷底)。

噪声预测模型 \epsilon_\theta(x_t, t) 本质上就是局部梯度的另一种表达形式(存在一个负常数倍关系),因此它其实就在扮演得分函数 s_\theta(x_t, t) 的角色。这也再次验证了扩散模型与基于得分的生成模型是高度统一的:对噪声的预测,本质上等效于对得分的预测。

3.2 分类器指引——给AI装上“GPS”

既然得分函数告诉模型“往哪边走是对的”,那如果我们用分类器告诉模型“往条件 y 的方向走”,不就行了吗?这正是分类器指引的思路。这条方向其实就是条件分类器对当前图像的梯度:\nabla_{x_t} \log p(y|x_t)。

  • 无条件得分:\nabla_{x_t} \log p(x_t)(模型觉得哪条路自然)

  • 分类器梯度:\nabla_{x_t} \log p(y|x_t)(分类器觉得哪条路更符合 y)

将这两股力量按公式“有条件得分=无条件得分+γ×分类器梯度”融合,模型就能在保持自然的同时,坚决朝着指令 y 前进。

缺点也很明显: 你必须额外训练一个独立的分类器。而且这个分类器要处理“加了噪声的模糊图”,和常规训练好的分类器很难完美兼容。

3.3 无分类器指引——一个模型干两份活

既然训练一个独立分类器这么麻烦,能不能用一个模型同时学会“无条件生成”和“有条件生成”,然后在生成时把两者结合起来?这就是大名鼎鼎的无分类器指引(Classifier-Free Guidance,简称CFG) 的核心思想。

原理其实简单到令人惊讶:我们在训练时,让模型以一定比例随机丢掉条件信息。比如10%的概率把labels设为None,让模型在这种情况下进行无条件训练;其余90%的概率正常传labels,进行有条件训练。

  • 当传入labels=None时,模型只根据时间步去噪,学到的就是“无条件得分”。

  • 当传入具体labels时,模型同时利用时间步和标签去噪,学到的就是“有条件得分”。

最后在生成时,CFG按以下公式将两者结合:

最终预测=无条件预测+γ×(有条件预测−无条件预测)最终预测=无条件预测+γ×(有条件预测−无条件预测)

  • \gamma(称为Guidance Scale)越大,模型就越“听话”,生成的图像更贴合你的指令。

  • \gamma 越小,模型就越“自由”,生成的图像更有创意和多样性。

这种方法的优势在于不依赖任何外部预训练分类器,只需一个模型,训练极其简单,生成时又能精准控制“听话”的程度。

补充两条进阶视角:

近年来,学术界仍在持续优化 CFG 的底层理论——例如 2025 年底的研究已开始分析声音信号与几何纠缠的根本原因

;另一组工作则专门解析线性 CFG 所内含的“均值偏移”和“类别特征放大”机制

你熟悉的

“反向提示词”技术

在数学上恰好对应这种无分类器指引:把不需要的信息对应的条件 p(y|x_t) 低维嵌入

Φ

放到无条件部分里,让模型在生成时“踩刹车”绕过它。


四、代码实战(二):无分类器指引的完整实现

现在我们把上面讲的理论转化为可运行的代码,看看 CFG 到底有多简单、多强大。

核心改动有三点:

  1. 训练时随机丢弃条件

    :以一定概率(比如10%)将labels设为None,让模型在无条件模式下训练。

  2. 生成时使用CFG公式

    :同时计算model(x, t, labels)(有条件预测)和model(x, t)(无条件预测),然后按 γ 系数混合。

  3. 配置可供调节的引导系数 γ

    :γ 越大,生成结果越“听指令”;γ 越小,结果越有随机多样性。

import math
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
 
# ========== 超参数 ==========
img_size = 28
batch_size = 128
num_timesteps = 1000
epochs = 10
lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
def show_images(images, labels=None, rows=2, cols=10):
    fig = plt.figure(figsize=(cols, rows))
    i = 0
    for r in range(rows):
        for c in range(cols):
            ax = fig.add_subplot(rows, cols, i + 1)
            plt.imshow(images[i], cmap='gray')
            if labels is not None:
                ax.set_xlabel(labels[i].item())
            ax.get_axes().set_ticklabels([])
            ax.get_axes().set_ticks([])
            i += 1
    plt.tight_layout()
    plt.show()
 
def _pos_encoding(time_idx, output_dim, device='cpu'):
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)
    i = torch.arange(0, D, device=device)
    div_term = torch.exp(i / D * math.log(10000))
    v[0::2] = torch.sin(t / div_term[0::2])
    v[1::2] = torch.cos(t / div_term[1::2])
    return v
 
def pos_encoding(timesteps, output_dim, device='cpu'):
    batch_size = len(timesteps)
    device = timesteps.device
    v = torch.zeros(batch_size, output_dim, device=device)
    for i in range(batch_size):
        v[i] = _pos_encoding(timesteps[i], output_dim, device)
    return v
 
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch)
        )
 
    def forward(self, x, v):
        N, C, _, _ = x.shape
        v = self.mlp(v)
        v = v.view(N, C, 1, 1)
        y = self.convs(x + v)
        return y
 
class UNetCond(nn.Module):
    def __init__(self, in_ch=1, time_embed_dim=100, num_labels=None):
        super().__init__()
        self.time_embed_dim = time_embed_dim
 
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)
 
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
 
        if num_labels is not None:
            self.label_emb = nn.Embedding(num_labels, time_embed_dim)
 
    def forward(self, x, timesteps, labels=None):
        t = pos_encoding(timesteps, self.time_embed_dim)
 
        if labels is not None:
            t += self.label_emb(labels)
 
        x1 = self.down1(x, t)
        x = self.maxpool(x1)
        x2 = self.down2(x, t)
        x = self.maxpool(x2)
        x = self.bot1(x, t)
        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x, t)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x, t)
        x = self.out(x)
        return x
 
 
# ========== 带 CFG 的 Diffuser ==========
class Diffuser:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        self.num_timesteps = num_timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
 
    def add_noise(self, x_0, t):
        t_idx = t - 1
        alpha_bar = self.alpha_bars[t_idx]
        alpha_bar = alpha_bar.view(alpha_bar.size(0), 1, 1, 1)
        noise = torch.randn_like(x_0, device=self.device)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise
 
    def denoise(self, model, x, t, labels, gamma):
        """带 CFG 的去噪函数 —— 最关键的部分!"""
        t_idx = t - 1
        alpha = self.alphas[t_idx]
        alpha_bar = self.alpha_bars[t_idx]
        alpha_bar_prev = self.alpha_bars[t_idx-1]
 
        N = alpha.size(0)
        alpha = alpha.view(N, 1, 1, 1)
        alpha_bar = alpha_bar.view(N, 1, 1, 1)
        alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)
 
        model.eval()
        with torch.no_grad():
            eps_cond = model(x, t, labels)    # 有条件预测
            eps_uncond = model(x, t)          # 无条件预测
            # CFG 核心公式:最终预测 = 无条件 + gamma * (有条件 - 无条件)
            eps = eps_uncond + gamma * (eps_cond - eps_uncond)
        model.train()
 
        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0
        mu = (x - ((1-alpha) / torch.sqrt(1-alpha_bar)) * eps) / torch.sqrt(alpha)
        std = torch.sqrt((1-alpha) * (1-alpha_bar_prev) / (1-alpha_bar))
        return mu + noise * std
 
    def reverse_to_img(self, x):
        x = x * 255
        x = x.clamp(0, 255)
        x = x.to(torch.uint8)
        x = x.cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(x)
 
    def sample(self, model, x_shape=(20, 1, 28, 28), labels=None, gamma=3.0):
        """生成函数,带 CFG 的引导系数 gamma"""
        batch_size = x_shape[0]
        x = torch.randn(x_shape, device=self.device)
        if labels is None:
            labels = torch.randint(0, 10, (batch_size,), device=self.device)
        for i in tqdm(range(self.num_timesteps, 0, -1)):
            t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t, labels, gamma)
        images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
        return images, labels
 
 
# ========== 数据加载 ==========
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='./data', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
# ========== 初始化 ==========
diffuser = Diffuser(num_timesteps, device=device)
model = UNetCond(num_labels=10)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
 
# ========== 训练(关键:随机丢弃条件) ==========
losses = []
for epoch in range(epochs):
    loss_sum = 0.0
    cnt = 0
 
    # 每轮结束后生成一次,观察 gamma 效果(可以尝试 gamma=1.5, 3.0, 5.0)
    images, labels = diffuser.sample(model, gamma=3.0)
    show_images(images, labels)
 
    for images, labels in tqdm(dataloader):
        optimizer.zero_grad()
        x = images.to(device)
        labels = labels.to(device)
        t = torch.randint(1, num_timesteps+1, (len(x),), device=device)
 
        # ===== 关键改动:随机丢弃标签 =====
        # 10% 的概率进行无条件训练,让模型学会没有标签时也能去噪
        if np.random.random() < 0.1:
            labels = None
 
        x_noisy, noise = diffuser.add_noise(x, t)
        noise_pred = model(x_noisy, t, labels)
        loss = F.mse_loss(noise, noise_pred)
 
        loss.backward()
        optimizer.step()
 
        loss_sum += loss.item()
        cnt += 1
 
    loss_avg = loss_sum / cnt
    losses.append(loss_avg)
    print(f'Epoch {epoch} | Loss: {loss_avg:.4f}')
 
# 最终生成展示
images, labels = diffuser.sample(model, gamma=3.0)
show_images(images, labels)

运行结果解读: 你可以尝试修改 sample() 函数中的 gamma 参数来感受它的魔力:

  • gamma = 1.0

    :相当于不加引导,模型自由发挥。

  • gamma = 3.0

    :模型比较听指令,生成结果与标签高度一致。

  • gamma = 5.0

    :极度听从指令,但可能会牺牲一些图像的自然度和多样性。

这种一拉滑块就能控制“听话程度”的体验,就是 CFG 最迷人的地方。


五、登堂入室:从MNIST到Stable Diffusion的广阔天地

我们已经从零搭建了一个能听懂数字指令的MNIST手写体生成器,但这只是万里长征的第一步。当我们放眼现代顶尖的AI绘画系统(如 Stable Diffusion),会发现它们虽然体量巨大,但其底层控制逻辑与我们今天搭建的模型惊人地相似。

  • 在像素空间运行太慢了

    :直接在 1024×1024 大小的图像上计算,对算力的消耗是不可思议的。

  • 潜在扩散模型(LDM)的解决方案

    :先将图像压缩到一个只有原图几十分之一大小的潜在空间(Latent Space),在压缩空间里进行所有复杂的扩散与去噪计算,最后再解压回原始尺寸。

  • 文本编码器的进化

    :我们用的是简单的nn.Embedding数字标签,而现代模型通常使用 CLIP 等大规模预训练模型作为文本编码器,将任何自然语言(比如“一只穿宇航服的柴犬”)转换成模型能理解的向量。

  • ControlNet与生成控制力天花板

    :利用 ControlNet 等附加控制模块,你甚至可以通过边缘图、深度图甚至人体姿态骨架来精确控制图像的构图和内容。


六、总结

今天我们完成了一段从理论到实践的完整旅程:

  • 理解核心原理

    :条件扩散模型通过在去噪网络中添加额外的条件输入(如文本、标签等),实现了可控的AI图像生成。

  • 亲手实现模型

    :我们用PyTorch从零搭建了一个带条件U-Net的扩散模型,成功实现了MNIST数字的条件生成。

  • 掌握无分类器指引

    :我们深入剖析了CFG技术的原理,并实现了通过一个γ系数就能精准控制“听话程度”的强大功能。

  • 展望未来世界

    :以Stable Diffusion为代表的现代模型,利用潜在空间扩散和文本编码器实现了更高分辨率、更强大的可控生成能力。

你已经不再是AI绘画的门外汉,而是一个掌握了核心底层技术的搭建者。现在,去创造属于你自己的“AI画图大师”吧!