21-DiT详解:扩散模型遇上Transformer的图像生成革命
引言
DiT(Diffusion Transformer)是Meta AI在2023年提出的突破性工作,它用纯Transformer架构实现扩散模型,在ImageNet 256×256生成任务上达到了FID 2.27的业界最佳水平,并首次在图像生成模型中展现出清晰的scaling law特性。
本文目标:深入理解DiT的四个核心组件(Patchify、向量化、位置编码、AdaLN-Zero)、推理机制和训练过程。
适合人群:了解Transformer基础和扩散模型原理的读者。
第一部分:扩散模型的数学基础
前向过程:从图像到噪声
扩散模型的前向过程是一个马尔可夫链,逐步向图像添加高斯噪声:
q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
关键性质是可以直接从 x0 跳到任意 xt:
xt=αˉtx0+1−αˉtε
其中 αˉt=∏i=1t(1−βi),ε∼N(0,I)。
反向过程:学习去噪
训练目标是学习一个神经网络 εθ(xt,t) 预测噪声:
L=Et,x0,ε[∥ε−εθ(xt,t)∥2]
DiT就是 εθ 的具体实现。
第二部分:DiT的四大核心组件
DiT的核心思想是将图像视为token序列,用Transformer处理。整个架构包含四个关键设计:
组件一:Patchify(切块) - 从2D到1D的转换
Patchify的本质
Patchify是将2D图像转换为1D token序列的过程。这是将Transformer应用于图像的前提。
给定图像 x∈RH×W×C,选择patch大小 p(通常16或8),将图像切分成 N=p2HW 个不重叠的patch。
每个patch是一个 p×p×C 的立方体,flatten后得到 p2C 维向量。所有patch排列成序列:
Patches∈RN×(p2C)
为什么Patchify是合理的?
局部性原理:自然图像具有强局部相关性。一个 16×16 的patch(256像素)通常包含一个完整的局部语义单元。
计算效率的权衡:
- 逐像素处理:256×256 图像有65536个token,自注意力复杂度 O(N2)=O(4.3×109)
- patch大小 p=16:只有256个token,复杂度 O(6.5×104),降低了6.6万倍
信息无损:切块是可逆操作,不丢失任何像素信息。
Patch排列的顺序
DiT采用光栅扫描顺序(raster-scan order):从左到右、从上到下依次排列。
虽然Transformer的自注意力是位置不变的(打乱patch顺序输出也会相应打乱),但通过位置编码可以让模型理解patch的空间位置关系。
Patchify的深层意义
Patchify不仅是技术手段,更是认知范式的转变:
- CNN的视角:图像是2D网格,通过卷积核滑动提取局部特征
- Transformer的视角:图像是patch的集合,每个patch通过全局注意力与其他patch交互
这种转变使得模型可以直接建模长距离依赖,而不受卷积感受野的限制。
组件二:Linear Projection(向量化) - 从像素到语义
Embedding的数学定义
Linear Projection将每个patch从原始像素空间映射到高维语义空间。
zi=E⋅vec(patchi)+b
其中:
- E∈Rd×(p2C) 是投影矩阵(可学习)
- d 是Transformer的隐藏维度(如768、1024)
- vec 表示将patch展平成向量
所有patch embedding组成序列:
Z=[z1,z2,…,zN]∈RN×d
为什么需要Projection?
1. 维度标准化
不同的patch大小导致不同的输入维度:
- p=8:82×3=192 维
- p=16:162×3=768 维
投影到统一的 d 维,使得模型架构与patch大小解耦,提供了架构的灵活性。
2. 语义提升
原始像素值(如RGB=[125, 200, 89])是低层次的信号,投影矩阵 E 学习将其映射到高层次语义空间。
类比:Word Embedding将离散的词ID映射到连续的语义向量空间,Patch Embedding做的是类似的事情。
3. 计算效率
实践中,Linear Projection通常用卷积层实现:
Conv2D(k=p,s=p,in=C,out=d)
这等价于对每个patch做矩阵乘法,但利用了卷积的并行计算优势。
Projection的初始化
投影矩阵的初始化对训练至关重要。DiT使用Xavier初始化:
E∼U(−p2C+d6,p2C+d6)
这保证了初始时每层的激活值方差相近,避免梯度消失/爆炸。
组件三:Positional Encoding(位置编码) - 告诉模型"哪里"
为什么Transformer必须有位置编码?
Transformer的自注意力机制是置换等变的(permutation equivariant):
Attention(shuffle(X))=shuffle(Attention(X))
这意味着如果打乱输入顺序,输出也会相应打乱。Transformer本身无法区分patch的位置。
但图像任务中,位置信息极其关键:
- 天空通常在上方,草地在下方
- 物体的空间关系("猫在沙发上")依赖于位置理解
因此必须显式注入位置信息。
DiT的2D正弦位置编码
DiT采用固定的2D正弦位置编码(inherited from ViT)。
对于位置 (i,j) 的patch(第i行,第j列),其位置编码是:
PE(i,j)=[PEx(i),PEy(j)]
其中x和y坐标分别编码为:
PEx(i,2k)=sin(100002k/di)
PEx(i,2k+1)=cos(100002k/di)
最终的2D位置编码是x和y编码的拼接:
PE2D(i,j)∈Rd
前 d/2 维编码x坐标,后 d/2 维编码y坐标。
正弦位置编码的数学优势
1. 周期性与连续性
正弦函数是连续平滑的,相邻位置的编码向量相近,这符合图像的空间连续性假设。
2. 相对位置的可表达性
通过三角恒等式:
sin(α+β)=sinαcosβ+cosαsinβ
模型可以从绝对位置编码中推导出相对位置关系。例如,位置 (i+1,j) 的编码可以通过位置 (i,j) 的编码线性表示。
3. 外推能力
理论上,正弦编码可以泛化到训练时未见过的更大图像尺寸。虽然实践中效果有限,但这是可学习位置编码不具备的特性。
4. 参数效率
位置编码是固定的(不参与训练),节省了 N×d 个参数。
位置编码的注入:加法 vs 拼接
DiT使用加法注入:
Zwith_pos=Z+PE
为什么不用拼接?
- 加法:RN×d+RN×d=RN×d,维度不变
- 拼接:[RN×d;RN×d]=RN×2d,计算量翻倍
理论上,如果 d 足够大,加法空间就足以让模型将"内容"和"位置"信息解耦。
实际上,这是一个线性子空间分解的假设:
Z+PE=Zcontent+Zposition
模型通过学习将混合的信息分离到不同的子空间。
组件四:AdaLN-Zero - 条件注入的核心创新
AdaLN-Zero是DiT最重要的创新,解决了"如何将时间步t和类别c注入Transformer"这一核心问题。
扩散模型的条件注入难题
扩散模型需要接收两类信息:
- 内容信息:噪声图像 xt
- 条件信息:
- 时间步 t:当前处于扩散过程的哪个阶段(关键!)
- 类别标签 c:生成什么类别的图像
传统方法:
- 加法注入:x+f(t,c) —— 太简单,条件易被覆盖
- 拼接注入:[x;f(t,c)] —— 增加序列长度,计算量增大
- Cross-Attention:将条件作为Key/Value —— 复杂度高 O(N×M)
DiT提出了AdaLN(Adaptive Layer Normalization),一种高效且表达力强的方案。
Adaptive Layer Normalization的数学原理
标准的Layer Normalization:
LN(x)=γ⊙σx−μ+β
其中 γ,β 是固定的可学习参数。
AdaLN的核心思想:让 γ,β 依赖于条件信息:
γ(c),β(c)=MLP(c)
AdaLN(x,c)=γ(c)⊙σx−μ+β(c)
其中 c=f(t,c) 是时间步和类别的嵌入向量。
直观理解:调制(Modulation)
AdaLN本质上是用条件信息调制特征的分布。
- γ(c):控制特征的尺度(scale)
- β(c):控制特征的偏移(shift)
不同的条件 c 产生不同的 γ,β,从而引导网络产生不同的输出。
类比:想象一个收音机,条件信息是调频旋钮,γ,β 是调制信号,特征 x 是被调制的载波。
AdaLN-Zero:Zero Initialization的关键改进
DiT在AdaLN基础上加入了Zero Initialization,这是训练稳定性的核心。
标准的DiT Block结构:
h1=x+α1(c)⊙Attention(AdaLN(x,c))
h2=h1+α2(c)⊙MLP(AdaLN(h1,c))
其中 α1,α2 是门控参数,也由条件生成:
[γ1,β1,α1,γ2,β2,α2]=MLPmodulation(c)
Zero Initialization的定义:
MLPmodulation=W2⋅SiLU(W1c+b1)+b2
初始化时:
W2=0,b2=0
这保证了训练初始时:
γ1=γ2=1,β1=β2=0,α1=α2=0
因此:
h1=x+0⋅Attention(⋯)=x
h2=x+0⋅MLP(⋯)=x
整个网络初始时是恒等映射:f(x)=x。
为什么Zero Initialization如此重要?
1. 梯度流动的畅通性
深度网络训练的核心挑战是梯度消失/爆炸。
在恒等映射下,梯度可以无损地反向传播:
∂x∂L=∂h2∂L⋅∂x∂h2=∂h2∂L⋅I
其中 I 是单位矩阵,梯度直接传递,不会衰减。
2. 从简单到复杂的学习路径
随着训练进行,门控参数 α1,α2 从0逐渐增大,模型逐步学习利用注意力和MLP的输出。
这是一种curriculum learning(课程学习)策略:先学简单的(恒等映射),再学复杂的(注意力模式)。
3. 残差连接的极致体现
残差连接(ResNet)的核心公式:
h=x+F(x)
当 F(x)=0 时,网络退化为恒等映射,保证了至少不会比浅层网络差。
AdaLN-Zero通过zero initialization,强制初始时F(x)=0,这是残差思想的最彻底实践。
AdaLN vs 其他条件注入方式
| 方法 | 计算复杂度 | 表达能力 | 训练稳定性 |
|---|
| 加法注入 | O(1) | 弱 | 中 |
| 拼接注入 | O(N) | 中 | 中 |
| Cross-Attention | O(N⋅M) | 强 | 中 |
| AdaLN-Zero | O(1) | 强 | 优 |
AdaLN-Zero的优势:
- 零额外计算:不增加序列长度,不增加注意力计算
- 强表达力:通过调制归一化参数,影响每一层的特征分布
- 训练稳定:zero initialization保证梯度流畅通
第三部分:DiT的推理过程
推理就是从纯噪声逐步去噪,生成清晰图像。
DDPM采样:严格的概率过程
DDPM(Denoising Diffusion Probabilistic Models)是最原始的采样算法,严格遵循扩散模型的概率推导。
单步去噪公式:
xt−1=αt1(xt−1−αˉt1−αtεθ(xt,t,c))+σtz
其中:
- εθ(xt,t,c) 是DiT预测的噪声
- z∼N(0,I) 是新采样的随机噪声
- σt=β~t 是后验方差
完整流程:
- 初始化 xT∼N(0,I)(纯高斯噪声)
- 对于 t=T,T−1,…,1:
- 前向传播DiT:ε^=εθ(xt,t,c)
- 计算均值:μt=αt1(xt−1−αˉt1−αtε^)
- 采样噪声:z∼N(0,I)
- 更新:xt−1=μt+σtz
- 返回 x0
特点:
- 慢:需要1000步,每步都要前向传播DiT
- 质量高:每步添加适量随机性,生成多样性好
- 理论清晰:严格遵循后验分布 q(xt−1∣xt,x0)
DDIM采样:确定性加速
DDIM(Denoising Diffusion Implicit Models)通过确定性过程实现加速。
核心思想:不采样新噪声,而是走确定性的"直线路径"。
DDIM公式:
xt−1=αˉt−1predicted x0αˉtxt−1−αˉtε^+1−αˉt−1ε^
这个公式的直观理解:
- 用当前 xt 和预测噪声 ε^,估计干净图像:
x^0=αˉtxt−1−αˉtε^
- 用 x^0 和 ε^,重新组合成 xt−1:
xt−1=αˉt−1x^0+1−αˉt−1ε^
关键区别:
- DDPM:每步采样新噪声 z,引入随机性
- DDIM:重复使用同一个噪声估计 ε^,是确定性过程
加速原理:由于是确定性的,可以跳步采样。例如:
- DDPM:1000→999→998→⋯→1→0(1000步)
- DDIM:1000→950→900→⋯→50→0(20步)
特点:
- 快:50步达到DDPM 1000步的效果,速度提升20倍
- 确定性:相同初始噪声和条件,生成完全相同的图像
- 质量略降:FID略高于DDPM,但肉眼难以区分
Classifier-Free Guidance:提升条件遵循度
CFG(Classifier-Free Guidance)是提升生成质量的关键技术。
问题:标准条件生成可能"不够听话"。指定生成"猫",模型可能生成模糊的猫,或混合其他动物特征。
解决方案:训练时同时学习条件生成和无条件生成,推理时"放大"条件影响。
CFG训练
训练时,以概率 p=0.1 将类别标签置空:
c′={∅c概率 0.1概率 0.9
其中 ∅ 用特殊token表示(如类别ID=1000)。
这样模型学会了两种模式:
- εθ(xt,t,c):给定类别c的条件生成
- εθ(xt,t,∅):无条件生成
CFG推理
推理时,将两者线性组合:
ε~=εθ(xt,t,∅)+w⋅(εθ(xt,t,c)−εθ(xt,t,∅))
其中 w 是guidance scale(通常w=7.5)。
数学直观:
- εθ(xt,t,c)−εθ(xt,t,∅):条件相对于无条件的"差异方向"
- w>1:沿着这个方向走得更远,放大条件影响
- w=1:标准条件生成
- w=0:无条件生成
效果:
| Guidance scale w | 类别一致性 | 图像多样性 | 图像质量 |
|---|
| 1.0 | 低 | 高 | 一般 |
| 3.0-5.0 | 中 | 中 | 好 |
| 7.5 (推荐) | 高 | 中低 | 最好 |
| 15.0+ | 过高 | 低 | 过饱和、失真 |
代价:CFG需要推理两次(条件+无条件),推理时间翻倍。但效果提升显著,是工业标准。
完整推理流程
结合DDIM和CFG:
输入:
- 类别 c(如"猫"的ID=281)
- 采样步数 S=50
- Guidance scale w=7.5
算法:
1. 确定时间步序列:τ = [1000, 950, 900, ..., 50, 0](均匀采样S步)
2. 初始化:x ← N(0, I)
3. For t in τ[:-1]:
t_next ← τ中t的下一个时间步
# 条件预测
ε_cond ← DiT(x, t, c)
# 无条件预测
ε_uncond ← DiT(x, t, ∅)
# CFG组合
ε̂ ← ε_uncond + w * (ε_cond - ε_uncond)
# 估计x₀
x̂₀ ← (x - √(1-ᾱₜ)·ε̂) / √ᾱₜ
# DDIM更新
x ← √ᾱₜ_ₙₑₓₜ · x̂₀ + √(1-ᾱₜ_ₙₑₓₜ) · ε̂
4. Return x
时间成本(DiT-XL,A100 GPU):
- DDPM 1000步:约60秒/图
- DDIM 50步 + CFG:约6秒/图
第四部分:DiT的训练过程
训练目标
DiT的训练是简单的噪声预测任务:
L=Et,x0,ε,c[∥ε−εθ(xt,t,c)∥2]
其中 xt=αˉtx0+1−αˉtε。
训练算法
单个训练step:
1. 采样一批数据:(x₀, c) ~ 数据集
2. 采样时间步:t ~ Uniform(1, T)
3. 采样噪声:ε ~ N(0, I)
4. 前向加噪:xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε
5. 预测噪声:ε̂ = DiT(xₜ, t, c)
6. 计算损失:L = ‖ε̂ - ε‖²
7. 反向传播:更新参数θ
关键训练细节
1. 噪声调度(Noise Schedule)
βt 的设计影响训练效果。DiT使用线性调度:
βt=βmin+T−1t−1(βmax−βmin)
典型值:βmin=0.0001,βmax=0.02,T=1000。
这意味着:
- 早期(t小):βt很小,加噪缓慢,图像几乎不变
- 后期(t大):βt接近0.02,加噪快速,图像迅速变成纯噪声
2. Classifier-Free Guidance训练
如前所述,训练时10%概率drop类别:
c′={null tokencp=0.1p=0.9
这让模型同时学会两种生成模式。
3. 学习率调度
DiT使用warmup + cosine decay:
ηt=ηmin+21(ηmax−ηmin)(1+cos(πT−twt−tw))
其中:
- tw=10000:warmup步数
- ηmax=10−4:峰值学习率
- ηmin=10−5:最小学习率
前10000步线性增长到 ηmax,之后按余弦函数衰减。
原理:warmup避免初期梯度过大导致发散;cosine decay比step decay更平滑。
4. EMA(Exponential Moving Average)
维护参数的指数移动平均:
θEMA←μθEMA+(1−μ)θ
其中 μ=0.9999。
推理时使用 θEMA 而非 θ。
原理:EMA相当于对训练轨迹上的多个checkpoint做平滑,减少单个模型的抖动,提升生成质量和稳定性。
5. 混合精度训练
使用FP16计算,同时维护FP32主权重:
收益:
- 训练速度提升1.5-2倍
- 显存占用减半
- 精度损失可忽略
训练规模与成本
DiT-XL的训练配置:
| 项目 | 数值 |
|---|
| 参数量 | 675M |
| 数据集 | ImageNet(130万图像,1000类) |
| Batch size | 256(8卡 × 32/卡) |
| 训练步数 | 7M steps |
| 训练时长 | 约1个月(8×A100 80GB) |
| 总计算量 | 约10M GPU-hours |
| FID(256×256) | 2.27 |
Scaling Law:DiT的惊人发现
DiT首次在图像生成模型中展现出清晰的scaling law:
| 模型 | 参数量 | 深度 | 宽度 | FID ↓ |
|---|
| DiT-S | 33M | 12层 | 384 | 9.62 |
| DiT-B | 130M | 12层 | 768 | 5.31 |
| DiT-L | 458M | 24层 | 1024 | 3.04 |
| DiT-XL | 675M | 28层 | 1152 | 2.27 |
关键观察:
- 性能持续提升:从DiT-S到DiT-XL,FID持续下降,没有饱和迹象
- 对数线性关系:FID与log(参数量)近似线性关系
- 类似LLM:这与语言模型的scaling law特性一致
意义:
- 更大的模型 → 更好的生成质量(确定性规律)
- 为投资更大模型提供了理论依据
- 预示着10B+参数的扩散模型可能带来质的飞跃
总结:DiT的意义与启示
核心贡献
1. 架构统一
证明了Transformer可以作为扩散模型的通用backbone,图像生成不再需要特定领域的架构设计。
2. AdaLN-Zero
提出了优雅的条件注入机制,在零额外计算成本下实现强大的表达能力和训练稳定性。
3. Scaling Law
首次在图像生成中展现scaling特性,为"训练更大模型"提供了理论支持。
4. 性能突破
FID 2.27(256×256 ImageNet),超越所有基于卷积的方法。
DiT的局限
1. 计算复杂度:自注意力是 O(N2),分辨率越高越慢
2. 推理时间:即使DDIM,仍需50步,比单次前向慢50倍
3. 数据需求:需要大规模数据(百万级)才能充分发挥scaling优势
4. 条件类型:目前主要支持类别标签,对长文本支持有限
未来方向
1. 更高效的注意力:Sparse Attention、Linear Attention、Flash Attention
2. 更快的采样:Consistency Models(一步生成)、Latent Diffusion(低维空间扩散)
3. 更大的模型:DiT-XXL(10B参数)在更大数据集上训练
4. 多模态扩展:文本到图像、视频生成、3D生成
关键启示
- Transformer的通用性:不仅NLP,CV也适用
- Scaling的威力:更大的模型带来更好的效果
- 架构细节的重要性:AdaLN-Zero这样的创新带来质的提升
- 条件注入的本质:如何注入比注入什么更重要
DiT代表了扩散模型从CNN到Transformer的范式转变,这与NLP从RNN到Transformer的转变如出一辙。
Transformer + Diffusion = 图像生成的未来。
参考文献
- Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
- Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
- Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
- Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.