深度拆解 DiT:扩散模型与 Transformer 的巅峰结合

3 阅读4分钟

21-DiT详解:扩散模型遇上Transformer的图像生成革命

引言

DiT(Diffusion Transformer)是Meta AI在2023年提出的突破性工作,它用纯Transformer架构实现扩散模型,在ImageNet 256×256生成任务上达到了FID 2.27的业界最佳水平,并首次在图像生成模型中展现出清晰的scaling law特性。

本文目标:深入理解DiT的四个核心组件(Patchify、向量化、位置编码、AdaLN-Zero)、推理机制和训练过程。

适合人群:了解Transformer基础和扩散模型原理的读者。


第一部分:扩散模型的数学基础

前向过程:从图像到噪声

扩散模型的前向过程是一个马尔可夫链,逐步向图像添加高斯噪声:

q(xtxt1)=N(xt;1βtxt1,βtI)q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)

关键性质是可以直接从 x0x_0 跳到任意 xtx_t

xt=αˉtx0+1αˉtεx_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \varepsilon

其中 αˉt=i=1t(1βi)\bar{\alpha}_t = \prod_{i=1}^{t}(1-\beta_i)εN(0,I)\varepsilon \sim \mathcal{N}(0, I)

反向过程:学习去噪

训练目标是学习一个神经网络 εθ(xt,t)\varepsilon_\theta(x_t, t) 预测噪声:

L=Et,x0,ε[εεθ(xt,t)2]\mathcal{L} = \mathbb{E}_{t, x_0, \varepsilon}\left[\|\varepsilon - \varepsilon_\theta(x_t, t)\|^2\right]

DiT就是 εθ\varepsilon_\theta 的具体实现。


第二部分:DiT的四大核心组件

DiT的核心思想是将图像视为token序列,用Transformer处理。整个架构包含四个关键设计:


组件一:Patchify(切块) - 从2D到1D的转换

Patchify的本质

Patchify是将2D图像转换为1D token序列的过程。这是将Transformer应用于图像的前提。

给定图像 xRH×W×Cx \in \mathbb{R}^{H \times W \times C},选择patch大小 pp(通常16或8),将图像切分成 N=HWp2N = \frac{HW}{p^2} 个不重叠的patch。

每个patch是一个 p×p×Cp \times p \times C 的立方体,flatten后得到 p2Cp^2C 维向量。所有patch排列成序列:

PatchesRN×(p2C)\text{Patches} \in \mathbb{R}^{N \times (p^2C)}

为什么Patchify是合理的?

局部性原理:自然图像具有强局部相关性。一个 16×1616 \times 16 的patch(256像素)通常包含一个完整的局部语义单元。

计算效率的权衡

  • 逐像素处理:256×256256\times256 图像有65536个token,自注意力复杂度 O(N2)=O(4.3×109)O(N^2) = O(4.3 \times 10^9)
  • patch大小 p=16p=16:只有256个token,复杂度 O(6.5×104)O(6.5 \times 10^4),降低了6.6万倍

信息无损:切块是可逆操作,不丢失任何像素信息。

Patch排列的顺序

DiT采用光栅扫描顺序(raster-scan order):从左到右、从上到下依次排列。

虽然Transformer的自注意力是位置不变的(打乱patch顺序输出也会相应打乱),但通过位置编码可以让模型理解patch的空间位置关系。

Patchify的深层意义

Patchify不仅是技术手段,更是认知范式的转变

  • CNN的视角:图像是2D网格,通过卷积核滑动提取局部特征
  • Transformer的视角:图像是patch的集合,每个patch通过全局注意力与其他patch交互

这种转变使得模型可以直接建模长距离依赖,而不受卷积感受野的限制。


组件二:Linear Projection(向量化) - 从像素到语义

Embedding的数学定义

Linear Projection将每个patch从原始像素空间映射到高维语义空间。

zi=Evec(patchi)+b\mathbf{z}_i = \mathbf{E} \cdot \text{vec}(\text{patch}_i) + \mathbf{b}

其中:

  • ERd×(p2C)\mathbf{E} \in \mathbb{R}^{d \times (p^2C)} 是投影矩阵(可学习)
  • dd 是Transformer的隐藏维度(如768、1024)
  • vec\text{vec} 表示将patch展平成向量

所有patch embedding组成序列:

Z=[z1,z2,,zN]RN×d\mathbf{Z} = [\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_N] \in \mathbb{R}^{N \times d}

为什么需要Projection?

1. 维度标准化

不同的patch大小导致不同的输入维度:

  • p=8p=882×3=1928^2 \times 3 = 192
  • p=16p=16162×3=76816^2 \times 3 = 768

投影到统一的 dd 维,使得模型架构与patch大小解耦,提供了架构的灵活性

2. 语义提升

原始像素值(如RGB=[125, 200, 89])是低层次的信号,投影矩阵 E\mathbf{E} 学习将其映射到高层次语义空间。

类比:Word Embedding将离散的词ID映射到连续的语义向量空间,Patch Embedding做的是类似的事情。

3. 计算效率

实践中,Linear Projection通常用卷积层实现:

Conv2D(k=p,s=p,in=C,out=d)\text{Conv2D}(k=p, s=p, \text{in}=C, \text{out}=d)

这等价于对每个patch做矩阵乘法,但利用了卷积的并行计算优势。

Projection的初始化

投影矩阵的初始化对训练至关重要。DiT使用Xavier初始化

EU(6p2C+d,6p2C+d)\mathbf{E} \sim \mathcal{U}\left(-\sqrt{\frac{6}{p^2C + d}}, \sqrt{\frac{6}{p^2C + d}}\right)

这保证了初始时每层的激活值方差相近,避免梯度消失/爆炸。


组件三:Positional Encoding(位置编码) - 告诉模型"哪里"

为什么Transformer必须有位置编码?

Transformer的自注意力机制是置换等变的(permutation equivariant):

Attention(shuffle(X))=shuffle(Attention(X))\text{Attention}(\text{shuffle}(X)) = \text{shuffle}(\text{Attention}(X))

这意味着如果打乱输入顺序,输出也会相应打乱。Transformer本身无法区分patch的位置

但图像任务中,位置信息极其关键:

  • 天空通常在上方,草地在下方
  • 物体的空间关系("猫在沙发上")依赖于位置理解

因此必须显式注入位置信息。

DiT的2D正弦位置编码

DiT采用固定的2D正弦位置编码(inherited from ViT)。

对于位置 (i,j)(i, j) 的patch(第ii行,第jj列),其位置编码是:

PE(i,j)=[PEx(i),PEy(j)]\text{PE}(i, j) = [\text{PE}_x(i), \text{PE}_y(j)]

其中x和y坐标分别编码为:

PEx(i,2k)=sin(i100002k/d)\text{PE}_x(i, 2k) = \sin\left(\frac{i}{10000^{2k/d}}\right)
PEx(i,2k+1)=cos(i100002k/d)\text{PE}_x(i, 2k+1) = \cos\left(\frac{i}{10000^{2k/d}}\right)

最终的2D位置编码是x和y编码的拼接:

PE2D(i,j)Rd\text{PE}_{2D}(i,j) \in \mathbb{R}^d

d/2d/2 维编码x坐标,后 d/2d/2 维编码y坐标。

正弦位置编码的数学优势

1. 周期性与连续性

正弦函数是连续平滑的,相邻位置的编码向量相近,这符合图像的空间连续性假设

2. 相对位置的可表达性

通过三角恒等式:

sin(α+β)=sinαcosβ+cosαsinβ\sin(\alpha + \beta) = \sin\alpha\cos\beta + \cos\alpha\sin\beta

模型可以从绝对位置编码中推导出相对位置关系。例如,位置 (i+1,j)(i+1, j) 的编码可以通过位置 (i,j)(i, j) 的编码线性表示。

3. 外推能力

理论上,正弦编码可以泛化到训练时未见过的更大图像尺寸。虽然实践中效果有限,但这是可学习位置编码不具备的特性。

4. 参数效率

位置编码是固定的(不参与训练),节省了 N×dN \times d 个参数。

位置编码的注入:加法 vs 拼接

DiT使用加法注入:

Zwith_pos=Z+PE\mathbf{Z}_{\text{with\_pos}} = \mathbf{Z} + \mathbf{PE}

为什么不用拼接?

  • 加法RN×d+RN×d=RN×d\mathbb{R}^{N \times d} + \mathbb{R}^{N \times d} = \mathbb{R}^{N \times d},维度不变
  • 拼接[RN×d;RN×d]=RN×2d[\mathbb{R}^{N \times d}; \mathbb{R}^{N \times d}] = \mathbb{R}^{N \times 2d},计算量翻倍

理论上,如果 dd 足够大,加法空间就足以让模型将"内容"和"位置"信息解耦。

实际上,这是一个线性子空间分解的假设:

Z+PE=Zcontent+Zposition\mathbf{Z} + \mathbf{PE} = \mathbf{Z}_{\text{content}} + \mathbf{Z}_{\text{position}}

模型通过学习将混合的信息分离到不同的子空间。


组件四:AdaLN-Zero - 条件注入的核心创新

AdaLN-Zero是DiT最重要的创新,解决了"如何将时间步tt和类别cc注入Transformer"这一核心问题。

扩散模型的条件注入难题

扩散模型需要接收两类信息:

  1. 内容信息:噪声图像 xtx_t
  2. 条件信息
    • 时间步 tt:当前处于扩散过程的哪个阶段(关键!)
    • 类别标签 cc:生成什么类别的图像

传统方法:

  • 加法注入x+f(t,c)\mathbf{x} + f(t, c) —— 太简单,条件易被覆盖
  • 拼接注入[x;f(t,c)][\mathbf{x}; f(t, c)] —— 增加序列长度,计算量增大
  • Cross-Attention:将条件作为Key/Value —— 复杂度高 O(N×M)O(N \times M)

DiT提出了AdaLN(Adaptive Layer Normalization),一种高效且表达力强的方案。

Adaptive Layer Normalization的数学原理

标准的Layer Normalization:

LN(x)=γxμσ+β\text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta

其中 γ,β\gamma, \beta 是固定的可学习参数。

AdaLN的核心思想:让 γ,β\gamma, \beta 依赖于条件信息

γ(c),β(c)=MLP(c)\gamma(\mathbf{c}), \beta(\mathbf{c}) = \text{MLP}(\mathbf{c})
AdaLN(x,c)=γ(c)xμσ+β(c)\text{AdaLN}(\mathbf{x}, \mathbf{c}) = \gamma(\mathbf{c}) \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta(\mathbf{c})

其中 c=f(t,c)\mathbf{c} = f(t, c) 是时间步和类别的嵌入向量。

直观理解:调制(Modulation)

AdaLN本质上是用条件信息调制特征的分布

  • γ(c)\gamma(\mathbf{c}):控制特征的尺度(scale)
  • β(c)\beta(\mathbf{c}):控制特征的偏移(shift)

不同的条件 c\mathbf{c} 产生不同的 γ,β\gamma, \beta,从而引导网络产生不同的输出。

类比:想象一个收音机,条件信息是调频旋钮,γ,β\gamma, \beta 是调制信号,特征 x\mathbf{x} 是被调制的载波。

AdaLN-Zero:Zero Initialization的关键改进

DiT在AdaLN基础上加入了Zero Initialization,这是训练稳定性的核心。

标准的DiT Block结构:

h1=x+α1(c)Attention(AdaLN(x,c))\mathbf{h}_1 = \mathbf{x} + \alpha_1(\mathbf{c}) \odot \text{Attention}(\text{AdaLN}(\mathbf{x}, \mathbf{c}))
h2=h1+α2(c)MLP(AdaLN(h1,c))\mathbf{h}_2 = \mathbf{h}_1 + \alpha_2(\mathbf{c}) \odot \text{MLP}(\text{AdaLN}(\mathbf{h}_1, \mathbf{c}))

其中 α1,α2\alpha_1, \alpha_2门控参数,也由条件生成:

[γ1,β1,α1,γ2,β2,α2]=MLPmodulation(c)[\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2] = \text{MLP}_{\text{modulation}}(\mathbf{c})

Zero Initialization的定义

MLPmodulation=W2SiLU(W1c+b1)+b2\text{MLP}_{\text{modulation}} = W_2 \cdot \text{SiLU}(W_1 \mathbf{c} + b_1) + b_2

初始化时:

W2=0,b2=0W_2 = \mathbf{0}, \quad b_2 = \mathbf{0}

这保证了训练初始时:

γ1=γ2=1,β1=β2=0,α1=α2=0\gamma_1 = \gamma_2 = 1, \quad \beta_1 = \beta_2 = 0, \quad \alpha_1 = \alpha_2 = 0

因此:

h1=x+0Attention()=x\mathbf{h}_1 = \mathbf{x} + 0 \cdot \text{Attention}(\cdots) = \mathbf{x}
h2=x+0MLP()=x\mathbf{h}_2 = \mathbf{x} + 0 \cdot \text{MLP}(\cdots) = \mathbf{x}

整个网络初始时是恒等映射f(x)=xf(\mathbf{x}) = \mathbf{x}

为什么Zero Initialization如此重要?

1. 梯度流动的畅通性

深度网络训练的核心挑战是梯度消失/爆炸

在恒等映射下,梯度可以无损地反向传播:

Lx=Lh2h2x=Lh2I\frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot \frac{\partial \mathbf{h}_2}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot I

其中 II 是单位矩阵,梯度直接传递,不会衰减。

2. 从简单到复杂的学习路径

随着训练进行,门控参数 α1,α2\alpha_1, \alpha_2 从0逐渐增大,模型逐步学习利用注意力和MLP的输出。

这是一种curriculum learning(课程学习)策略:先学简单的(恒等映射),再学复杂的(注意力模式)。

3. 残差连接的极致体现

残差连接(ResNet)的核心公式:

h=x+F(x)\mathbf{h} = \mathbf{x} + F(\mathbf{x})

F(x)=0F(\mathbf{x}) = 0 时,网络退化为恒等映射,保证了至少不会比浅层网络差。

AdaLN-Zero通过zero initialization,强制初始时F(x)=0F(\mathbf{x}) = 0,这是残差思想的最彻底实践。

AdaLN vs 其他条件注入方式

方法计算复杂度表达能力训练稳定性
加法注入O(1)O(1)
拼接注入O(N)O(N)
Cross-AttentionO(NM)O(N \cdot M)
AdaLN-ZeroO(1)O(1)

AdaLN-Zero的优势:

  • 零额外计算:不增加序列长度,不增加注意力计算
  • 强表达力:通过调制归一化参数,影响每一层的特征分布
  • 训练稳定:zero initialization保证梯度流畅通

第三部分:DiT的推理过程

推理就是从纯噪声逐步去噪,生成清晰图像

DDPM采样:严格的概率过程

DDPM(Denoising Diffusion Probabilistic Models)是最原始的采样算法,严格遵循扩散模型的概率推导。

单步去噪公式

xt1=1αt(xt1αt1αˉtεθ(xt,t,c))+σtzx_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \varepsilon_\theta(x_t, t, c) \right) + \sigma_t z

其中:

  • εθ(xt,t,c)\varepsilon_\theta(x_t, t, c) 是DiT预测的噪声
  • zN(0,I)z \sim \mathcal{N}(0, I) 是新采样的随机噪声
  • σt=β~t\sigma_t = \sqrt{\tilde{\beta}_t} 是后验方差

完整流程

  1. 初始化 xTN(0,I)x_T \sim \mathcal{N}(0, I)(纯高斯噪声)
  2. 对于 t=T,T1,,1t = T, T-1, \ldots, 1
    • 前向传播DiT:ε^=εθ(xt,t,c)\hat{\varepsilon} = \varepsilon_\theta(x_t, t, c)
    • 计算均值:μt=1αt(xt1αt1αˉtε^)\mu_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\varepsilon}\right)
    • 采样噪声:zN(0,I)z \sim \mathcal{N}(0, I)
    • 更新:xt1=μt+σtzx_{t-1} = \mu_t + \sigma_t z
  3. 返回 x0x_0

特点

  • :需要1000步,每步都要前向传播DiT
  • 质量高:每步添加适量随机性,生成多样性好
  • 理论清晰:严格遵循后验分布 q(xt1xt,x0)q(x_{t-1}|x_t, x_0)

DDIM采样:确定性加速

DDIM(Denoising Diffusion Implicit Models)通过确定性过程实现加速。

核心思想:不采样新噪声,而是走确定性的"直线路径"。

DDIM公式

xt1=αˉt1xt1αˉtε^αˉtpredicted x0+1αˉt1ε^x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}_t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}_t}}}_{\text{predicted } x_0} + \sqrt{1-\bar{\alpha}_{t-1}} \hat{\varepsilon}

这个公式的直观理解:

  1. 用当前 xtx_t 和预测噪声 ε^\hat{\varepsilon},估计干净图像:
x^0=xt1αˉtε^αˉt\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}_t}}
  1. x^0\hat{x}_0ε^\hat{\varepsilon},重新组合成 xt1x_{t-1}
xt1=αˉt1x^0+1αˉt1ε^x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}} \hat{\varepsilon}

关键区别

  • DDPM:每步采样新噪声 zz,引入随机性
  • DDIM:重复使用同一个噪声估计 ε^\hat{\varepsilon},是确定性过程

加速原理:由于是确定性的,可以跳步采样。例如:

  • DDPM:1000999998101000 \to 999 \to 998 \to \cdots \to 1 \to 0(1000步)
  • DDIM:10009509005001000 \to 950 \to 900 \to \cdots \to 50 \to 0(20步)

特点

  • :50步达到DDPM 1000步的效果,速度提升20倍
  • 确定性:相同初始噪声和条件,生成完全相同的图像
  • 质量略降:FID略高于DDPM,但肉眼难以区分

Classifier-Free Guidance:提升条件遵循度

CFG(Classifier-Free Guidance)是提升生成质量的关键技术。

问题:标准条件生成可能"不够听话"。指定生成"猫",模型可能生成模糊的猫,或混合其他动物特征。

解决方案:训练时同时学习条件生成和无条件生成,推理时"放大"条件影响。

CFG训练

训练时,以概率 p=0.1p=0.1 将类别标签置空:

c={概率 0.1c概率 0.9c' = \begin{cases} \emptyset & \text{概率 } 0.1 \\ c & \text{概率 } 0.9 \end{cases}

其中 \emptyset 用特殊token表示(如类别ID=1000)。

这样模型学会了两种模式:

  • εθ(xt,t,c)\varepsilon_\theta(x_t, t, c):给定类别cc的条件生成
  • εθ(xt,t,)\varepsilon_\theta(x_t, t, \emptyset):无条件生成
CFG推理

推理时,将两者线性组合:

ε~=εθ(xt,t,)+w(εθ(xt,t,c)εθ(xt,t,))\tilde{\varepsilon} = \varepsilon_\theta(x_t, t, \emptyset) + w \cdot \left(\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset)\right)

其中 ww 是guidance scale(通常w=7.5w=7.5)。

数学直观

  • εθ(xt,t,c)εθ(xt,t,)\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset):条件相对于无条件的"差异方向"
  • w>1w > 1:沿着这个方向走得更远,放大条件影响
  • w=1w = 1:标准条件生成
  • w=0w = 0:无条件生成

效果

Guidance scale ww类别一致性图像多样性图像质量
1.0一般
3.0-5.0
7.5 (推荐)中低最好
15.0+过高过饱和、失真

代价:CFG需要推理两次(条件+无条件),推理时间翻倍。但效果提升显著,是工业标准。

完整推理流程

结合DDIM和CFG:

输入

  • 类别 cc(如"猫"的ID=281)
  • 采样步数 S=50S=50
  • Guidance scale w=7.5w=7.5

算法

1. 确定时间步序列:τ = [1000, 950, 900, ..., 50, 0](均匀采样S步)
2. 初始化:x ← N(0, I)
3. For t in τ[:-1]:
     t_next ← τ中t的下一个时间步

     # 条件预测
     ε_cond ← DiT(x, t, c)

     # 无条件预测
     ε_uncond ← DiT(x, t, ∅)

     # CFG组合
     ε̂ ← ε_uncond + w * (ε_cond - ε_uncond)

     # 估计x₀
     x̂₀ ← (x - √(1-ᾱₜ)·ε̂) / √ᾱₜ

     # DDIM更新
     x ← √ᾱₜ_ₙₑₓₜ · x̂₀ + √(1-ᾱₜ_ₙₑₓₜ) · ε̂

4. Return x

时间成本(DiT-XL,A100 GPU):

  • DDPM 1000步:约60秒/图
  • DDIM 50步 + CFG:约6秒/图

第四部分:DiT的训练过程

训练目标

DiT的训练是简单的噪声预测任务

L=Et,x0,ε,c[εεθ(xt,t,c)2]\mathcal{L} = \mathbb{E}_{t, x_0, \varepsilon, c}\left[\|\varepsilon - \varepsilon_\theta(x_t, t, c)\|^2\right]

其中 xt=αˉtx0+1αˉtεx_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon

训练算法

单个训练step

1. 采样一批数据:(x₀, c) ~ 数据集
2. 采样时间步:t ~ Uniform(1, T)
3. 采样噪声:ε ~ N(0, I)
4. 前向加噪:xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε
5. 预测噪声:ε̂ = DiT(xₜ, t, c)
6. 计算损失:L = ‖ε̂ - ε‖²
7. 反向传播:更新参数θ

关键训练细节

1. 噪声调度(Noise Schedule)

βt\beta_t 的设计影响训练效果。DiT使用线性调度

βt=βmin+t1T1(βmaxβmin)\beta_t = \beta_{\min} + \frac{t-1}{T-1}(\beta_{\max} - \beta_{\min})

典型值:βmin=0.0001,βmax=0.02,T=1000\beta_{\min} = 0.0001, \beta_{\max} = 0.02, T=1000

这意味着:

  • 早期(tt小):βt\beta_t很小,加噪缓慢,图像几乎不变
  • 后期(tt大):βt\beta_t接近0.02,加噪快速,图像迅速变成纯噪声
2. Classifier-Free Guidance训练

如前所述,训练时10%概率drop类别:

c={null tokenp=0.1cp=0.9c' = \begin{cases} \text{null token} & p=0.1 \\ c & p=0.9 \end{cases}

这让模型同时学会两种生成模式。

3. 学习率调度

DiT使用warmup + cosine decay

ηt=ηmin+12(ηmaxηmin)(1+cos(πttwTtw))\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \frac{t - t_w}{T - t_w}\right)\right)

其中:

  • tw=10000t_w = 10000:warmup步数
  • ηmax=104\eta_{\max} = 10^{-4}:峰值学习率
  • ηmin=105\eta_{\min} = 10^{-5}:最小学习率

前10000步线性增长到 ηmax\eta_{\max},之后按余弦函数衰减。

原理:warmup避免初期梯度过大导致发散;cosine decay比step decay更平滑。

4. EMA(Exponential Moving Average)

维护参数的指数移动平均:

θEMAμθEMA+(1μ)θ\theta_{\text{EMA}} \leftarrow \mu \theta_{\text{EMA}} + (1-\mu)\theta

其中 μ=0.9999\mu = 0.9999

推理时使用 θEMA\theta_{\text{EMA}} 而非 θ\theta

原理:EMA相当于对训练轨迹上的多个checkpoint做平滑,减少单个模型的抖动,提升生成质量和稳定性。

5. 混合精度训练

使用FP16计算,同时维护FP32主权重:

  • 前向传播、梯度计算:FP16
  • 参数更新:FP32

收益

  • 训练速度提升1.5-2倍
  • 显存占用减半
  • 精度损失可忽略

训练规模与成本

DiT-XL的训练配置:

项目数值
参数量675M
数据集ImageNet(130万图像,1000类)
Batch size256(8卡 × 32/卡)
训练步数7M steps
训练时长约1个月(8×A100 80GB)
总计算量约10M GPU-hours
FID(256×256)2.27

Scaling Law:DiT的惊人发现

DiT首次在图像生成模型中展现出清晰的scaling law:

模型参数量深度宽度FID ↓
DiT-S33M12层3849.62
DiT-B130M12层7685.31
DiT-L458M24层10243.04
DiT-XL675M28层11522.27

关键观察

  1. 性能持续提升:从DiT-S到DiT-XL,FID持续下降,没有饱和迹象
  2. 对数线性关系:FID与log(参数量)近似线性关系
  3. 类似LLM:这与语言模型的scaling law特性一致

意义

  • 更大的模型 → 更好的生成质量(确定性规律)
  • 为投资更大模型提供了理论依据
  • 预示着10B+参数的扩散模型可能带来质的飞跃

总结:DiT的意义与启示

核心贡献

1. 架构统一

证明了Transformer可以作为扩散模型的通用backbone,图像生成不再需要特定领域的架构设计。

2. AdaLN-Zero

提出了优雅的条件注入机制,在零额外计算成本下实现强大的表达能力和训练稳定性。

3. Scaling Law

首次在图像生成中展现scaling特性,为"训练更大模型"提供了理论支持。

4. 性能突破

FID 2.27(256×256 ImageNet),超越所有基于卷积的方法。

DiT的局限

1. 计算复杂度:自注意力是 O(N2)O(N^2),分辨率越高越慢

2. 推理时间:即使DDIM,仍需50步,比单次前向慢50倍

3. 数据需求:需要大规模数据(百万级)才能充分发挥scaling优势

4. 条件类型:目前主要支持类别标签,对长文本支持有限

未来方向

1. 更高效的注意力:Sparse Attention、Linear Attention、Flash Attention

2. 更快的采样:Consistency Models(一步生成)、Latent Diffusion(低维空间扩散)

3. 更大的模型:DiT-XXL(10B参数)在更大数据集上训练

4. 多模态扩展:文本到图像、视频生成、3D生成

关键启示

  1. Transformer的通用性:不仅NLP,CV也适用
  2. Scaling的威力:更大的模型带来更好的效果
  3. 架构细节的重要性:AdaLN-Zero这样的创新带来质的提升
  4. 条件注入的本质:如何注入比注入什么更重要

DiT代表了扩散模型从CNN到Transformer的范式转变,这与NLP从RNN到Transformer的转变如出一辙。

Transformer + Diffusion = 图像生成的未来。


参考文献

  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
  2. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
  3. Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
  4. Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
  5. Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.