你有没有想过,AI是如何听懂你的指令,画出你想要的东西的?当你对Midjourney输入“一只穿着宇航服的柴犬”,它真的能生成那张图——这背后究竟发生了什么?
今天,我将带你亲手实现一个基础的文本控制AI绘图系统。虽然我们做的是“数字0~9”的控制,但原理,和那些动辄几十亿参数的大模型,完全一致。
一、更上一层楼:让AI听懂你的“命令”
在之前的项目中,我们的扩散模型虽然能生成MNIST手写数字,但它是完全不受控的——你无法告诉它“我要一个数字5”,它生成什么全凭运气。
1.1 核心思维:从无条件到有条件
想象一下你请一个画家画画:
-
无条件扩散模型
:你告诉画家“随便画点啥”。他画什么你都只能接受,完全看他的心情。
-
条件扩散模型
:你告诉画家“给我画一个数字8”。他听懂了你的指令,专门为你创作一个8。
这就是条件扩散模型的核心思想——我们在神经网络中引入了一个额外的输入,也就是条件y,告诉模型“我想要什么”。
条件y可以是各种各样的东西:一个数字标签(就像我们今天要做的MNIST手写数字)、一段文本描述、一张低分辨率的图像(这就是超分辨率技术),甚至是边缘检测图或姿态关键点。
1.2 简单的实现思路
如何让模型消化这个“条件”呢?关键是把y变成它能理解的数学形式。
-
神经网络不认识“5”这个整数,就像你不认识外星文一样。
-
我们需要一个翻译官,把“5”翻译成神经网络能理解的向量。
-
这个翻译官,在深度学习里叫做嵌入层(Embedding Layer) 。
具体的实现思路是这样的:
-
我们首先仍然使用正弦位置编码,把时间步信息t(比如当前是第100步去噪)变成模型能理解的向量。
-
然后,我们也用一个嵌入层,把输入的指令y(比如数字“5”)也变成一个特征向量。
-
最后,简单粗暴但极其有效:把这两个向量相加!此时,模型接收到的信息就同时包含了“时间信息”和“用户指令”。模型自然就知道,在这个时间点,它应该朝着“生成5”的方向去努力了。
这个方法的优雅之处在于,它对原始模型结构的改动极其微小,但效果却立竿见影。
二、代码实战(一):打造能听懂指令的AI画师
纸上得来终觉浅,绝知此事要躬行。下面我们动手实现上面讲的条件扩散模型。
(这里有一个细节需要注意:在本节及后续的所有代码实现中,DDPM的前向加噪过程使用了“累积极大值” \bar{\alpha}t;而在反向去噪计算 \mu\theta 时则使用“原始值” \alpha_t 与上一节的数学推导完全一致。denoise() 函数也必须同时接收这两个参数。
import math
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
# ========== 超参数设置 ==========
img_size = 28 # MNIST图像尺寸 28x28
batch_size = 128 # 批次大小
num_timesteps = 1000 # 扩散步数(DDPM的标准配置)
epochs = 10 # 训练轮数
lr = 1e-3 # 学习率
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ========== 辅助函数 ==========
def show_images(images, labels=None, rows=2, cols=10):
"""展示生成的图像(带标签)"""
fig = plt.figure(figsize=(cols, rows))
i = 0
for r in range(rows):
for c in range(cols):
ax = fig.add_subplot(rows, cols, i + 1)
plt.imshow(images[i], cmap='gray')
if labels is not None:
ax.set_xlabel(labels[i].item())
ax.get_axes().set_ticklabels([])
ax.get_axes().set_ticks([])
i += 1
plt.tight_layout()
plt.show()
def _pos_encoding(time_idx, output_dim, device='cpu'):
"""为单个时间步生成正弦位置编码"""
t, D = time_idx, output_dim
v = torch.zeros(D, device=device)
i = torch.arange(0, D, device=device)
# 关键的计算公式:div_term = 10000^(2i/D)
div_term = torch.exp(i / D * math.log(10000))
# 偶数位用正弦,奇数位用余弦
v[0::2] = torch.sin(t / div_term[0::2])
v[1::2] = torch.cos(t / div_term[1::2])
return v
def pos_encoding(timesteps, output_dim, device='cpu'):
"""为批次中的所有时间步生成正弦位置编码"""
batch_size = len(timesteps)
device = timesteps.device
v = torch.zeros(batch_size, output_dim, device=device)
for i in range(batch_size):
v[i] = _pos_encoding(timesteps[i], output_dim, device=device)
return v
# ========== 卷积块(带时间嵌入) ==========
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_embed_dim):
super().__init__()
# 双卷积层:Conv -> BN -> ReLU -> Conv -> BN -> ReLU
self.convs = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
# MLP将时间嵌入映射到合适的特征维度
self.mlp = nn.Sequential(
nn.Linear(time_embed_dim, in_ch),
nn.ReLU(),
nn.Linear(in_ch, in_ch)
)
def forward(self, x, v):
N, C, _, _ = x.shape
# 时间嵌入经过MLP后 reshape 为 (N, C, 1, 1)
v = self.mlp(v)
v = v.view(N, C, 1, 1)
# 将时间嵌入加到输入上(特征调制)
y = self.convs(x + v)
return y
# ========== 条件U-Net模型 ==========
class UNetCond(nn.Module):
def __init__(self, in_ch=1, time_embed_dim=100, num_labels=None):
super().__init__()
self.time_embed_dim = time_embed_dim
# U-Net的编码器(下采样路径)
self.down1 = ConvBlock(in_ch, 64, time_embed_dim) # 28 -> 28(保留尺寸经池化->14)
self.down2 = ConvBlock(64, 128, time_embed_dim) # 14 -> 14(保留尺寸经池化->7)
# 瓶颈层(最低分辨率)
self.bot1 = ConvBlock(128, 256, time_embed_dim) # 7 -> 7
# 解码器(上采样路径)
self.up2 = ConvBlock(128 + 256, 128, time_embed_dim) # 7 -> 14(拼接来自down2的特征)
self.up1 = ConvBlock(128 + 64, 64, time_embed_dim) # 14 -> 28(拼接来自down1的特征)
# 输出层
self.out = nn.Conv2d(64, in_ch, 1) # 1x1卷积输出噪声预测
self.maxpool = nn.MaxPool2d(2) # 2倍下采样
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') # 2倍上采样
# ========== 关键!:处理标签的嵌入层 ==========
if num_labels is not None:
# 将整数标签(0-9)转换为 time_embed_dim 维的向量
self.label_emb = nn.Embedding(num_labels, time_embed_dim)
def forward(self, x, timesteps, labels=None):
# 1. 将时间步转换为正弦位置编码
t = pos_encoding(timesteps, self.time_embed_dim)
# 2. 如果有标签,将标签转换为嵌入并加到时间编码上
if labels is not None:
# label_emb(labels) 的形状是 (batch_size, time_embed_dim)
# 直接加到时间编码上,两个信号融合
t += self.label_emb(labels)
# 3. U-Net 前向传播
# 编码器路径
x1 = self.down1(x, t) # 保存用于跳跃连接
x = self.maxpool(x1) # 下采样
x2 = self.down2(x, t) # 保存用于跳跃连接
x = self.maxpool(x2) # 下采样
# 瓶颈层
x = self.bot1(x, t)
# 解码器路径(带跳跃连接)
x = self.upsample(x)
x = torch.cat([x, x2], dim=1) # 拼接(跳跃连接)
x = self.up2(x, t)
x = self.upsample(x)
x = torch.cat([x, x1], dim=1) # 拼接(跳跃连接)
x = self.up1(x, t)
# 输出噪声预测
x = self.out(x)
return x
去噪扩散封装器(Diffuser)
封装正向加噪与反向去噪流程。
class Diffuser:
def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
self.num_timesteps = num_timesteps
self.device = device
# 线性噪声调度(beta从0.0001线性增加到0.02,DDPM论文的原始配置)
self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
self.alphas = 1 - self.betas # alpha_t = 1 - beta_t
self.alpha_bars = torch.cumprod(self.alphas, dim=0) # \bar{alpha}_t = Π alpha_{1..t}
def add_noise(self, x_0, t):
"""前向扩散:向干净图像添加噪声,得到 x_t"""
# t 从 1 到 T,索引需要 -1 才能对齐 alpha_bars[0] 对应 t=1
t_idx = t - 1
alpha_bar = self.alpha_bars[t_idx]
# reshape 为 (N, 1, 1, 1) 用于广播
alpha_bar = alpha_bar.view(alpha_bar.size(0), 1, 1, 1)
# 生成高斯噪声,并与干净图像按公式混合
noise = torch.randn_like(x_0, device=self.device)
# x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
return x_t, noise
def denoise(self, model, x, t, labels):
"""反向扩散:从 x_t 去噪得到 x_{t-1}"""
t_idx = t - 1
# 获取当前时间步的关键参数
alpha = self.alphas[t_idx] # alpha_t
alpha_bar = self.alpha_bars[t_idx] # \bar{alpha}_t
alpha_bar_prev = self.alpha_bars[t_idx-1] # \bar{alpha}_{t-1}(t=1时自动处理)
N = alpha.size(0)
alpha = alpha.view(N, 1, 1, 1)
alpha_bar = alpha_bar.view(N, 1, 1, 1)
alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)
# 使用模型预测噪声
model.eval()
with torch.no_grad():
eps = model(x, t, labels) # 【关键】:同时传入 labels!
model.train()
# 计算去噪均值 mu
# mu = (x - ( (1 - alpha) / sqrt(1 - alpha_bar) ) * eps) / sqrt(alpha)
mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
# 添加噪声(DDPM采样时的随机性)
noise = torch.randn_like(x, device=self.device)
noise[t == 1] = 0 # t=1时是最后一步,不添加噪声
return mu + noise * std
def reverse_to_img(self, x):
"""将张量数据转换为可显示的 PIL 图像"""
x = x * 255
x = x.clamp(0, 255)
x = x.to(torch.uint8)
x = x.cpu()
to_pil = transforms.ToPILImage()
return to_pil(x)
def sample(self, model, x_shape=(20, 1, 28, 28), labels=None):
"""从随机噪声开始,逐步去噪生成图像"""
batch_size = x_shape[0]
x = torch.randn(x_shape, device=self.device) # 纯随机噪声开始
if labels is None:
# 如果没给标签,就随机生成0~9的标签
labels = torch.randint(0, 10, (batch_size,), device=self.device)
# 从 T 步逐步去噪到 1 步
for i in tqdm(range(self.num_timesteps, 0, -1)):
t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
x = self.denoise(model, x, t, labels)
# 转换格式并返回
images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
return images, labels
# ========== 数据加载 ==========
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='../datasets', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ========== 初始化模型和优化器 ==========
diffuser = Diffuser(num_timesteps, device=device)
model = UNetCond(num_labels=10) # 10个类别(数字0-9)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
# ========== 训练循环 ==========
losses = []
for epoch in range(epochs):
loss_sum = 0.0
cnt = 0
# 每个 epoch 结束后生成一组图像,观察训练进展
images, labels = diffuser.sample(model)
show_images(images, labels)
for images, labels in tqdm(dataloader):
optimizer.zero_grad()
x = images.to(device)
labels = labels.to(device) # 【关键】:训练时也要提供标签!
t = torch.randint(1, num_timesteps+1, (len(x),), device=device)
# 添加噪声并预测
x_noisy, noise = diffuser.add_noise(x, t)
noise_pred = model(x_noisy, t, labels) # 【关键】:模型同时接收t和labels
loss = F.mse_loss(noise, noise_pred)
loss.backward()
optimizer.step()
loss_sum += loss.item()
cnt += 1
loss_avg = loss_sum / cnt
losses.append(loss_avg)
print(f'Epoch {epoch} | Loss: {loss_avg:.4f}')
# 最终生成展示
images, labels = diffuser.sample(model)
show_images(images, labels)
运行结果: 经过短短10轮训练,模型已经学会了根据标签生成对应的数字。虽然边缘还有些模糊,但它确实成功理解了你的“指令”。
三、AI的进阶之路:从得分函数到分类器指引
上一步的模型虽然能工作了,但它有时候会“偷懒”,不那么看重你给的条件,甚至可能会忽略。为了解决这个问题,我们需要引入一种更强大的技术,它的名字听起来很学术,但原理非常直观,这就是——指引(Guidance) 。
3.1 得分函数——AI内部的“导航仪”
在讨论指引之前,我们需要了解一下扩散模型内部是怎么工作的。扩散模型内部有一个重要的概念叫做得分函数,它是模型判断“这像不像一张真实图像”的内部标尺。数学上定义为对数概率密度相对于输入数据向量的梯度。
一句话理解得分函数:
想象你在一个黑暗的山谷里探索,你蒙着眼睛,目标是走到谷底。得分函数就像你脚下感知坡度的触觉——它会告诉你哪个方向是“下坡”,哪里是“上坡” 。模型就是循着这个“下坡方向”,一步步把噪声“修”成干净图像(数据点会自然聚集在概率高密度的谷底)。噪声预测模型 \epsilon_\theta(x_t, t) 本质上就是局部梯度的另一种表达形式(存在一个负常数倍关系),因此它其实就在扮演得分函数 s_\theta(x_t, t) 的角色。这也再次验证了扩散模型与基于得分的生成模型是高度统一的:对噪声的预测,本质上等效于对得分的预测。
3.2 分类器指引——给AI装上“GPS”
既然得分函数告诉模型“往哪边走是对的”,那如果我们用分类器告诉模型“往条件 y 的方向走”,不就行了吗?这正是分类器指引的思路。这条方向其实就是条件分类器对当前图像的梯度:\nabla_{x_t} \log p(y|x_t)。
-
无条件得分:\nabla_{x_t} \log p(x_t)(模型觉得哪条路自然)
-
分类器梯度:\nabla_{x_t} \log p(y|x_t)(分类器觉得哪条路更符合 y)
将这两股力量按公式“有条件得分=无条件得分+γ×分类器梯度”融合,模型就能在保持自然的同时,坚决朝着指令 y 前进。
缺点也很明显: 你必须额外训练一个独立的分类器。而且这个分类器要处理“加了噪声的模糊图”,和常规训练好的分类器很难完美兼容。
3.3 无分类器指引——一个模型干两份活
既然训练一个独立分类器这么麻烦,能不能用一个模型同时学会“无条件生成”和“有条件生成”,然后在生成时把两者结合起来?这就是大名鼎鼎的无分类器指引(Classifier-Free Guidance,简称CFG) 的核心思想。
原理其实简单到令人惊讶:我们在训练时,让模型以一定比例随机丢掉条件信息。比如10%的概率把labels设为None,让模型在这种情况下进行无条件训练;其余90%的概率正常传labels,进行有条件训练。
-
当传入labels=None时,模型只根据时间步去噪,学到的就是“无条件得分”。
-
当传入具体labels时,模型同时利用时间步和标签去噪,学到的就是“有条件得分”。
最后在生成时,CFG按以下公式将两者结合:
最终预测=无条件预测+γ×(有条件预测−无条件预测)最终预测=无条件预测+γ×(有条件预测−无条件预测)
-
\gamma(称为Guidance Scale)越大,模型就越“听话”,生成的图像更贴合你的指令。
-
\gamma 越小,模型就越“自由”,生成的图像更有创意和多样性。
这种方法的优势在于不依赖任何外部预训练分类器,只需一个模型,训练极其简单,生成时又能精准控制“听话”的程度。
补充两条进阶视角:
近年来,学术界仍在持续优化 CFG 的底层理论——例如 2025 年底的研究已开始分析声音信号与几何纠缠的根本原因
;另一组工作则专门解析线性 CFG 所内含的“均值偏移”和“类别特征放大”机制
。
你熟悉的
“反向提示词”技术
在数学上恰好对应这种无分类器指引:把不需要的信息对应的条件 p(y|x_t) 低维嵌入
Φ
放到无条件部分里,让模型在生成时“踩刹车”绕过它。
四、代码实战(二):无分类器指引的完整实现
现在我们把上面讲的理论转化为可运行的代码,看看 CFG 到底有多简单、多强大。
核心改动有三点:
-
训练时随机丢弃条件
:以一定概率(比如10%)将labels设为None,让模型在无条件模式下训练。
-
生成时使用CFG公式
:同时计算model(x, t, labels)(有条件预测)和model(x, t)(无条件预测),然后按 γ 系数混合。
-
配置可供调节的引导系数 γ
:γ 越大,生成结果越“听指令”;γ 越小,结果越有随机多样性。
import math
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
# ========== 超参数 ==========
img_size = 28
batch_size = 128
num_timesteps = 1000
epochs = 10
lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def show_images(images, labels=None, rows=2, cols=10):
fig = plt.figure(figsize=(cols, rows))
i = 0
for r in range(rows):
for c in range(cols):
ax = fig.add_subplot(rows, cols, i + 1)
plt.imshow(images[i], cmap='gray')
if labels is not None:
ax.set_xlabel(labels[i].item())
ax.get_axes().set_ticklabels([])
ax.get_axes().set_ticks([])
i += 1
plt.tight_layout()
plt.show()
def _pos_encoding(time_idx, output_dim, device='cpu'):
t, D = time_idx, output_dim
v = torch.zeros(D, device=device)
i = torch.arange(0, D, device=device)
div_term = torch.exp(i / D * math.log(10000))
v[0::2] = torch.sin(t / div_term[0::2])
v[1::2] = torch.cos(t / div_term[1::2])
return v
def pos_encoding(timesteps, output_dim, device='cpu'):
batch_size = len(timesteps)
device = timesteps.device
v = torch.zeros(batch_size, output_dim, device=device)
for i in range(batch_size):
v[i] = _pos_encoding(timesteps[i], output_dim, device)
return v
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_embed_dim):
super().__init__()
self.convs = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
self.mlp = nn.Sequential(
nn.Linear(time_embed_dim, in_ch),
nn.ReLU(),
nn.Linear(in_ch, in_ch)
)
def forward(self, x, v):
N, C, _, _ = x.shape
v = self.mlp(v)
v = v.view(N, C, 1, 1)
y = self.convs(x + v)
return y
class UNetCond(nn.Module):
def __init__(self, in_ch=1, time_embed_dim=100, num_labels=None):
super().__init__()
self.time_embed_dim = time_embed_dim
self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
self.down2 = ConvBlock(64, 128, time_embed_dim)
self.bot1 = ConvBlock(128, 256, time_embed_dim)
self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
self.out = nn.Conv2d(64, in_ch, 1)
self.maxpool = nn.MaxPool2d(2)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
if num_labels is not None:
self.label_emb = nn.Embedding(num_labels, time_embed_dim)
def forward(self, x, timesteps, labels=None):
t = pos_encoding(timesteps, self.time_embed_dim)
if labels is not None:
t += self.label_emb(labels)
x1 = self.down1(x, t)
x = self.maxpool(x1)
x2 = self.down2(x, t)
x = self.maxpool(x2)
x = self.bot1(x, t)
x = self.upsample(x)
x = torch.cat([x, x2], dim=1)
x = self.up2(x, t)
x = self.upsample(x)
x = torch.cat([x, x1], dim=1)
x = self.up1(x, t)
x = self.out(x)
return x
# ========== 带 CFG 的 Diffuser ==========
class Diffuser:
def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
self.num_timesteps = num_timesteps
self.device = device
self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
self.alphas = 1 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def add_noise(self, x_0, t):
t_idx = t - 1
alpha_bar = self.alpha_bars[t_idx]
alpha_bar = alpha_bar.view(alpha_bar.size(0), 1, 1, 1)
noise = torch.randn_like(x_0, device=self.device)
x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
return x_t, noise
def denoise(self, model, x, t, labels, gamma):
"""带 CFG 的去噪函数 —— 最关键的部分!"""
t_idx = t - 1
alpha = self.alphas[t_idx]
alpha_bar = self.alpha_bars[t_idx]
alpha_bar_prev = self.alpha_bars[t_idx-1]
N = alpha.size(0)
alpha = alpha.view(N, 1, 1, 1)
alpha_bar = alpha_bar.view(N, 1, 1, 1)
alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)
model.eval()
with torch.no_grad():
eps_cond = model(x, t, labels) # 有条件预测
eps_uncond = model(x, t) # 无条件预测
# CFG 核心公式:最终预测 = 无条件 + gamma * (有条件 - 无条件)
eps = eps_uncond + gamma * (eps_cond - eps_uncond)
model.train()
noise = torch.randn_like(x, device=self.device)
noise[t == 1] = 0
mu = (x - ((1-alpha) / torch.sqrt(1-alpha_bar)) * eps) / torch.sqrt(alpha)
std = torch.sqrt((1-alpha) * (1-alpha_bar_prev) / (1-alpha_bar))
return mu + noise * std
def reverse_to_img(self, x):
x = x * 255
x = x.clamp(0, 255)
x = x.to(torch.uint8)
x = x.cpu()
to_pil = transforms.ToPILImage()
return to_pil(x)
def sample(self, model, x_shape=(20, 1, 28, 28), labels=None, gamma=3.0):
"""生成函数,带 CFG 的引导系数 gamma"""
batch_size = x_shape[0]
x = torch.randn(x_shape, device=self.device)
if labels is None:
labels = torch.randint(0, 10, (batch_size,), device=self.device)
for i in tqdm(range(self.num_timesteps, 0, -1)):
t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
x = self.denoise(model, x, t, labels, gamma)
images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
return images, labels
# ========== 数据加载 ==========
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='./data', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ========== 初始化 ==========
diffuser = Diffuser(num_timesteps, device=device)
model = UNetCond(num_labels=10)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
# ========== 训练(关键:随机丢弃条件) ==========
losses = []
for epoch in range(epochs):
loss_sum = 0.0
cnt = 0
# 每轮结束后生成一次,观察 gamma 效果(可以尝试 gamma=1.5, 3.0, 5.0)
images, labels = diffuser.sample(model, gamma=3.0)
show_images(images, labels)
for images, labels in tqdm(dataloader):
optimizer.zero_grad()
x = images.to(device)
labels = labels.to(device)
t = torch.randint(1, num_timesteps+1, (len(x),), device=device)
# ===== 关键改动:随机丢弃标签 =====
# 10% 的概率进行无条件训练,让模型学会没有标签时也能去噪
if np.random.random() < 0.1:
labels = None
x_noisy, noise = diffuser.add_noise(x, t)
noise_pred = model(x_noisy, t, labels)
loss = F.mse_loss(noise, noise_pred)
loss.backward()
optimizer.step()
loss_sum += loss.item()
cnt += 1
loss_avg = loss_sum / cnt
losses.append(loss_avg)
print(f'Epoch {epoch} | Loss: {loss_avg:.4f}')
# 最终生成展示
images, labels = diffuser.sample(model, gamma=3.0)
show_images(images, labels)
运行结果解读: 你可以尝试修改 sample() 函数中的 gamma 参数来感受它的魔力:
-
gamma = 1.0
:相当于不加引导,模型自由发挥。
-
gamma = 3.0
:模型比较听指令,生成结果与标签高度一致。
-
gamma = 5.0
:极度听从指令,但可能会牺牲一些图像的自然度和多样性。
这种一拉滑块就能控制“听话程度”的体验,就是 CFG 最迷人的地方。
五、登堂入室:从MNIST到Stable Diffusion的广阔天地
我们已经从零搭建了一个能听懂数字指令的MNIST手写体生成器,但这只是万里长征的第一步。当我们放眼现代顶尖的AI绘画系统(如 Stable Diffusion),会发现它们虽然体量巨大,但其底层控制逻辑与我们今天搭建的模型惊人地相似。
-
在像素空间运行太慢了
:直接在 1024×1024 大小的图像上计算,对算力的消耗是不可思议的。
-
潜在扩散模型(LDM)的解决方案
:先将图像压缩到一个只有原图几十分之一大小的潜在空间(Latent Space),在压缩空间里进行所有复杂的扩散与去噪计算,最后再解压回原始尺寸。
-
文本编码器的进化
:我们用的是简单的nn.Embedding数字标签,而现代模型通常使用 CLIP 等大规模预训练模型作为文本编码器,将任何自然语言(比如“一只穿宇航服的柴犬”)转换成模型能理解的向量。
-
ControlNet与生成控制力天花板
:利用 ControlNet 等附加控制模块,你甚至可以通过边缘图、深度图甚至人体姿态骨架来精确控制图像的构图和内容。
六、总结
今天我们完成了一段从理论到实践的完整旅程:
-
理解核心原理
:条件扩散模型通过在去噪网络中添加额外的条件输入(如文本、标签等),实现了可控的AI图像生成。
-
亲手实现模型
:我们用PyTorch从零搭建了一个带条件U-Net的扩散模型,成功实现了MNIST数字的条件生成。
-
掌握无分类器指引
:我们深入剖析了CFG技术的原理,并实现了通过一个γ系数就能精准控制“听话程度”的强大功能。
-
展望未来世界
:以Stable Diffusion为代表的现代模型,利用潜在空间扩散和文本编码器实现了更高分辨率、更强大的可控生成能力。
你已经不再是AI绘画的门外汉,而是一个掌握了核心底层技术的搭建者。现在,去创造属于你自己的“AI画图大师”吧!