代码层面上学习diffusion(DDPM)模型

728 阅读8分钟

前言

借助 Python 库 denoising-diffusion-pytorch,通过调试与阅读源码来探究 DDPM(Denoising Diffusion Probabilistic Models)模型的结构与运作机理。本文留下阅读代码时的痕迹(后记:读完代码太不容易了。感谢无私奉献的博客、论文以及 ChatGPT)。

可以看看这个英文文章,对 diffusion 有较为全面的简要的基础叙述,最后也给出后续的改进工作。

DDPM 需要较长扩散步数才能得到较好效果。后来出现了 DDIM(Denoising Diffusion Implicit Models),不再限制扩散过程必须是一个马尔卡夫链,可以采用更小的采样步数来加速生成过程。

zhuanlan.zhihu.com/p/565698027 DDIM

zhuanlan.zhihu.com/p/563661713 DDPM

littlenyima.github.io/posts/27-sd… SDXL 与 SD 的区别

本文默认读者对 DDPM 的推导有概念。建议看看原论文。

本文内容较多结构较乱。善用 Ctrl + F。

本文显然并不完善甚至可能有错。也欢迎讨论。

本文主要包含以下内容:

  • U-Net 及其 forward 过程
  • Diffusion 采样流程
  • Diffusion 训练过程
  • 额外说明。一些提升训练效果和模型效果的 trick。有的还蛮重要

模型定义(Unet

使用 denoising_diffusion_pytorch 库实例化一个扩散模型,要使用以下代码:

import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim=64, dim_mults=(1, 2, 4, 8), flash_attn=True)

diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000)  # number of steps

可见,先实例化 Unet,然后再装入 GaussianDiffusionUnet 实例是这个模型的核心。

Unet.__init__()

从两个输入参数来看是怎样定义的 U-Net 维度。

dims:对于 dim=64, dim_mults=(1, 2, 4, 8),会有 dims = [64, 64, 128, 256, 512]

in_outin_out = list(zip(dims[:-1], dims[1:]))。结果上会有 in_out = [(64, 64), (64, 128), (128, 256), (256, 512)]

Unet.forward()

Unet 层实例的 forward 代码过程如下。输入本时间步的图像 x 和时间步 t 后:

若没有明确说明,默认是对 x 张量进行操作;

其中指明维度具体数和模型层数量的地方仅供参考,可能不同配置和同网络不同实例会有不同的维度数和层数量;

加粗部分是 U-Net 的残差结构。

  • Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  • 保存残差 r
  • self.time_mlp,将 timestep 值处理为 embedding 形式的时间嵌入
    • SinusoidalPosEmb,正弦位置编码
    • nn.Linear(64, 256)
    • nn.GELU
    • nn.Linear(256, 256)
  • 开始使用 self.downs 降维。self.downs 是一个 ModuleList,包含四组 block1 block2 attn downsample
    • block1ResnetBlock)。卷积 x -> x 与时间嵌入混合 -> 卷积 x。最终获得大小和维度不变的 x
      • ResnetBlock 的详细过程请看后文
    • h.append(x) 保存处理结果
    • block2ResnetBlock)。与 block1 结构一模一样
    • attnLinearAttention),且进行残差
      • 注意 self.downs[-1]attnAttention 实例。关于两种 attention 实例的区别请看后文
    • h.append(x) 保存处理结果
    • downsample
      • Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)。对 x y 轴元素进行两两切分,使得新 h 和新 w 会指向一个 2x2 的区块。总之是一种另类的降采样手法,把像素分散到 channel 维
      • Conv2d(256, 64, kernel_size=1, stride=1),用 1x1 卷积来降 channel 数
      • (注意 self.downs[-1]downsample 仅有 Conv2d(256, 512, kernel_size=3, stride=1, padding=1),没有改变维度 rearrange 的步骤)
  • self.mid_block1ResnetBlock),卷积以及时间嵌入混合
  • self.mid_attnAttention)且进行残差
  • self.mid_block2ResnetBlock),卷积以及时间嵌入混合
  • 开始使用 self.ups 升维。self.ups 是一个 ModuleList,包含四组 block1 block2 attn upsample
    • xh.pop() 拼接。现在的 x 维度是 [b, c, h, w],拼接发生在 c 维度
    • block1ResnetBlock),卷积以及时间嵌入混合
      • 上一步的拼接导致 x 的 c 维度从 512 变为 768。经过这一层后变回 512。后同
    • xh.pop() 拼接。拼接发生在 c 维度
    • block2ResnetBlock),卷积以及时间嵌入混合
    • attnself.downs[0] 时是 Attention,否则 LinearAttention),且进行残差
    • upsample
      • nn.Upsample(scale_factor=2.0, mode='nearest')
      • Conv2d(512, 256, kernel_size=3, stride=1, padding=1),用 3x3 卷积来降 channel 数
      • (注意 self.downs[-1]downsample 仅有 Conv2d(64, 64, kernel_size=3, stride=1, padding=1),没有 nn.Upsample 的步骤)
  • 此时 x 维度为 [4, 64, 128, 128]。已经还原为 self.init_conv 升维后、self.downs 降维前的状态
  • x 与残差 r 在通道维度进行拼接
  • self.final_res_blockResnetBlock),卷积以及时间嵌入混合
  • Conv2d(64, 3, kernel_size=1, stride=1),降 channel 数。x 的外形已经是通常的 image

ResnetBlock

ResnetBlock 是 U-Net 的重要组成部分。涉及到的流程:卷积 x -> x 与时间嵌入混合 -> 卷积 x

关于 ResnetBlock,输入 x 和时间嵌入 t 后:

  • 暂存残差
  • 时间嵌入经过 self.mlpnn.SiLU -> nn.Linear(256, 128)
  • 时间嵌入使用 .chunk 在维度上等分为两部分。至于原因请继续往下看
  • x 和时间嵌入输入到 self.block1Block
    • x 经过 Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    • x 经过 RMSNorm
    • 时间嵌入拆分为 scale shift,然后 x = x * (scale + 1) + shift
    • x 经过 nn.SiLU
    • x 经过 nn.Dropout
  • 仅将 x 输入到 self.block2Block
    • x 经过 Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    • x 经过 RMSNorm
    • x 经过 nn.SiLU
    • x 经过 nn.Dropout
  • x 经过 self.res_conv 映射到所需维度。若维度不需要修改,本层变为 Identity
  • 进行残差

关于时间嵌入的获得,代码实现有点绕,东一点西一点写得到处都是。汇总一下应该是一下这个路线,其中线性层维度仅供参考:

SinusoidalPosEmb -> nn.Linear(64, 256) -> nn.GELU -> nn.Linear(256, 256) -> nn.SiLU -> nn.Linear(256, 128) -> 维度等分为 scaleshift

很奇怪的是,整个模型几乎只使用 SiLU 激活,这里却唐突出现了 GELU。原因未知。

LinearAttention

LinearAttention 层实现了不同于 scaled_dot_product_attention 的一种更高效的 Attention 机制,适合处理超长序列(例如第一个降维层的序列长度长达 128 * 128)。

关于 LinearAttention,输入 x 后:

  • x 经过 RMSNorm
  • x 经过 Conv2d(64, 384, kernel_size=1, stride=1, bias=False),用 .chunk() 分离出 qkv
  • 借助 map() 分别对 q k v 转换维度:从 [b, c, x, y] 到 [b, h, c’, (x, y)],有 4 个头
  • self.mem_kv 是维度为 [2, h, c', n] 的可学习参数。此时会拆分为 mk mv
  • dim=-1 下拼接 k mk,还有 v mv,获得新的 k v
  • q = q.softmax(dim=-2)。即对 channel 维度施加 softmax
  • k = k.softmax(dim=-1)。即对 n 维度施加 softmax
  • q 乘上 dim_head ** 0.5
  • context = k @ v.transpose(2,3)
  • out = context.transpose(2,3) @ q
  • out 转换维度:从 [b, h, c', (x, y)] 到 [b, (h, c'), x, y]
  • Conv2d(128, 64, kernel_size=1, stride=1)
  • RMSNorm,然后返回值

Attention

可能是最后一个降维层的序列长度只有 16*16 已经足够短了,代码就选用了这个以 scaled_dot_product_attention 为核心的 Attention。

关于 Attention,输入 x 后:

  • x 经过 RMSNorm
  • x 经过 Conv2d(64, 384, kernel_size=1, stride=1, bias=False),用 .chunk() 分离出 qkv
  • 借助 map() 分别对 q k v 转换维度:从 [b, c, x, y] 到 [b, h, (x, y), c’]。这里与 LinearAttention 不同,c' 维度在最后
  • self.mem_kv 是维度为 [2, h, n, c'] 的可学习参数。此时会拆分为 mk mv
  • dim=-2 下拼接 k mk,还有 v mv,获得新的 k v
  • 使用 scaled_dot_product_attention 计算 q k v 的 attention
  • out 转换维度:从 [b, h, (x, y), c'] 到 [b, (h, c'), x, y]
  • Conv2d(128, 256, kernel_size=1, stride=1),然后返回值

采样过程(GaussianDiffusion.sample()

使用 DDPM 的话,进入到 GaussianDiffusion.sample() 会调用 GaussianDiffusion.p_sample_loop() 进行采样。

这里涉及到很多数学推导和优化 trick。部分 trick 放在了最后的 “额外说明” 一节进行补充。

p_sample_loop()

是 DDPM 采样主循环。不断生成各个步下的 img,完成整个采样过程。

p_sample_loop() 流程如下。

  1. 创建随机噪声(img = torch.randn(shape)),作为图像生成的起点
  2. 反向遍历 timesteps(从最大到最小)调用 self.p_sample(),不断更新 img
    1. 调用 self.p_sample(),获得 pred_imgx_start
  3. 逆向执行图像正则化(分布从 [-1, 1] 变为 [0, 1]),返回最后一步输出的 img

self.p_sample()

通过预测的均值和方差,获得本步预测的图像。

self.p_sample() 流程如下。

  1. 调用 self.p_mean_variance(),使用指定 x 和时间步,从模型推理获得:
    1. μ^\hat{\mu},预测均值 model_mean
    2. log(β^)\log(\hat{\beta}),后验方差的对数 model_log_variance
    3. x0x_0,模型预测的 x_start
  2. 从标准正态分布采样出一个 ϵ\epsilon noise
    1. 时间步为 0 时的噪声为全 0
  3. 通过 μ~+exp(0.5log(σ^2))ϵ\tilde{\mu}+\exp(0.5\log(\hat{\sigma}^2))\cdot \epsilon 算出 xt1x_{t-1} pred_img
    1. 公式等价于 μ~+σ~2ϵ\tilde{\mu}+\sqrt{\tilde{\sigma}^2}\cdot \epsilon
  4. 返回 pred_imgx_start

self.p_mean_variance()

通过模型预测出 x_start,得到本 step 的噪声均值 μ^\hat{\mu} 和方差 β^\hat{\beta}

self.p_mean_variance() 流程如下。

  1. 调用 self.model_predictions(),获得 x^0\hat{x}_0 x_start
    1. 也获得了 pred_noise,但没用上
    2. x_start 进行 .clamp_(-1., 1.)
  2. 调用 self.q_posterior,获得 xt1x_{t-1} 的后验均值、后验方差
    1. 使用 μ~t=αt1βt1αtx^0+αt(1αt1)1αtxt\tilde{\mu}_t=\frac{\sqrt{\overline{\alpha}_{t-1}}\beta_t}{1-\overline{\alpha}_t}\hat{x}_0+\frac{\sqrt{\alpha_t}(1-\overline{\alpha}_{t-1})}{1-\overline{\alpha}_t}x_t 计算后验均值 model_mean
    2. σ~t2=1αt11αtβt\tilde{\sigma}^2_t=\frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t}\beta_t 取值得后验方差 posterior_mean
    3. posterior_mean.log() 作为 posterior_log_variance_clipped
  3. 返回模型预测的均值 model_mean、后验方差 posterior_variance、后验方差的对数 posterior_log_variancex^0\hat{x}_0 x_start

self.model_predictions()

模型预测出 x^0\hat{x}_0。结合公式可以算出模型所预测的噪声均值 μ^\hat{\mu}

self.model_predictions() 流程如下。

  1. 输入 x 和时间步 tself.modelUnet)进行模型推理,获得 shape 完全一样的输出 model_output
  2. 默认启用 v-prediction 推导技巧。model_output 预测的是 velocity
    1. 调用 self.predict_start_from_v() 通过 x^0=αtxt1αtv\hat{x}_0=\sqrt{\overline{\alpha}_t}x_t-\sqrt{1-\overline{\alpha}_t}v 获得预估的 x0x_0 x_start
    2. 调用 self.predict_noise_from_start 通过 ϵ^=xtαtx^01αt\hat{\epsilon}=\frac{x_t-\sqrt{\overline{\alpha}_t}\hat{x}_0}{\sqrt{1-\overline{\alpha}_t}} 获得噪声估计 pred_noise
  3. 返回 pred_noisex_start

训练过程(GaussianDiffusion.forward()

.forward() 输入维度为 [batch, channel, hight, width] 的图像。流程如下:

  1. 产生 batch 个随机整数,用来指示时间步数 t
  2. 对输入图像应用 normalize。将 [0, 1] 的分布变为 [-1, 1]
  3. 调用 self.p_losses() 获得损失

self.p_losses()

用输入的一个 batch 的 x0x_0tt 进行模型推理,并获得 loss。

  1. 创建噪声 ϵ\epsilon
    1. offset_noise_strength 不为 0,则使用 offset noise 技巧
  2. 调用 self.q_sample,从 x0x_0ϵ\epsilon 和步数 tt 获得加噪图像 xtx_t
  3. xtx_ttt 输入到 self.modelUnet),获得模型输出 model_out
  4. 默认启用 v-prediction 推导技巧。model_output 预测的是 velocity。所以调用 self.predict_v() 通过 v=αtϵ1αtx0v=\sqrt{\overline{\alpha}_t}\epsilon-\sqrt{1-\overline{\alpha}_t}x_0 获得训练目标 target
  5. loss = F.mse_loss(model_output, target)
    1. 传入参数 reduction='none',以便于后面对不同时间步 tt 下的 loss 施加不同权重
  6. loss 施加权重。由于不同图片有不同的时间步 tt,施加的权重也有所不同
    1. 关于这一点,详见最后的 “额外说明” 一节
  7. 返回 loss

self.q_sample()

从给定的 x0x_0ϵ\epsilon 和步数 tt 获得加噪图像 xtx_t

  1. 若启用了 immiscible 训练技巧,对 noise 进行额外处理
    1. 关于 immiscible 详见最后 “额外说明” 一节
  2. 借助 xt=αtx0+1αtϵx_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\epsilon 采样出第 t 步的加噪图像 xtx_t

额外说明

sigmoid beta schedule,对 βi\beta_i 取值的另类方案

beta_schedule,控制不同时间步加入的噪声量。扩散模型加噪声的过程可以用 xt=1βtxt1+βtϵx_t=\sqrt{1-\beta_t}\cdot x_{t-1}+\sqrt{\beta_t}\cdot\epsilon 表示,不同的 beta scheduler 会决定 βt\beta_ttt 的值变化规则。具体来说,beta_schedule 最终会会使得 self 获得一个长度为时间步数 timesteps 的、值从 0 到 1 递增的一维 tensor。

denoising-diffusion-pytorch 代码默认使用 sigmoid_beta_schedule() 获得 betas。该函数需要输入时间步数 timesteps、起始 x 轴值 start、末尾 x 轴值 end 和缩放因子 tau

代码里说这个 schedule 方法来源于 arxiv.org/abs/2212.11… (《Scalable Adaptive Computation for Iterative Generation (2023)》),对大于 64x64 的图像效果良好。

但我搜索了一圈发现几乎没有使用这个方案的项目。例如 diffusers 库里预设的 scheduler 有 linear(默认)、scaled_linear 和用 torch.linspace().sigmoid() 的方案,就是没有本节介绍这个。

所以这一节权当多点了解。实际使用要用啥还得看看现有 SOTA 模型的实现(TODO 看看 SD3 用的 FlowMatchEulerDiscreteScheduler 到底是个啥)。

首先定义 t=[0,1,,timesteps+1]timestepst=\frac{[0,1,\dots,\mathrm{timesteps}+1]}{\mathrm{timesteps}},即 0 到 1 均匀分布的一维向量。

然后计算 αt\overline{\alpha}_t

αt=t=0t0βt=vendσ(t(endstart)+startτ)vendvstart=σ(endτ)σ(t(endstart)+startτ)σ(endτ)σ(startτ)\begin{split} \overline{\alpha}_t=\sum_{t=0}^{t_0} \beta_t&= \frac {v_{\mathrm{end}} - \sigma(\frac{t(\mathrm{end}-\mathrm{start})+\mathrm{start}}{\tau})} {v_{\mathrm{end}}-v_{\mathrm{start}}}\\ &= \frac {\sigma(\frac{\mathrm{end}}{\tau}) - \sigma(\frac{t(\mathrm{end}-\mathrm{start})+\mathrm{start}}{\tau})} {\sigma(\frac{\mathrm{end}}{\tau})-\sigma(\frac{\mathrm{start}}{\tau})} \end{split}

其中 σ\sigma 是 sigmoid 函数,τ\tau 是控制斜率的标量(默认为 1)。

理解 σ(t(endstart)+startτ)\sigma(\frac{t(\mathrm{end}-\mathrm{start})+\mathrm{start}}{\tau}):假设 start=-3,end=3,τ=1\tau=1,这是把 [-3, 3] 的数值用 sigmoid 映射。

可见,可以说 αt\overline{\alpha}_t 是在 [-3, 3] 上均匀采样的 sigmoid 函数进行拉伸与翻转,使其为最大值为 1 最小值为 0 的递减数组

为了得到 βi\beta_i,这样计算:

βi=1αtαt1\beta_i=1-\frac{\overline{\alpha}_{t}}{\overline{\alpha}_{t-1}}

由此获得的 βi\beta_i 会长成下图这样。可见在最后几十步会突然加入相当大比例的噪声。

beta scheduler 每步施加的噪声

v-prediction:模型预测 xxϵ\epsilon 的聚合

本节参考:

库代码默认启用 v-prediction 进行采样。模型预测的对象不再是原论文所述的 noise,而是 velocity。该方法在论文《Progressive Distillation for Fast Sampling of Diffusion Models (2022)》中被提出,用来解决模型蒸馏时纯噪声预测在 t[1]t[-1] 步骤不稳定的问题。

模型的预测对象是 ϵ\epsilon 噪声的话,该采样过程也被称为 ε-prediction;

若预测对象直接是图像 xx,则称为 x-prediction。

回忆 DDPM 的加噪声等式。对于输入的原始图像 x0x_0,通过以下式子获得 tt 时刻的加噪图像 xtx_t

xt=αtx0+1αtϵx_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\epsilon

可见,αt\sqrt{\overline{\alpha}_t}1αt\sqrt{1-\overline{\alpha}_t} 的平方和为 1。这就能用 sin(φ)+cos(φ)=1\sin(\varphi)+\cos(\varphi)=1 类比。换句话说,加噪公式可以视为:

xt=cos(φ)x0+sin(φ)ϵx_t=\cos(\varphi)x_0+\sin(\varphi)\epsilon

对式子右侧求导得到 φ\varphi 的导数,我们记为 velocity vv

v=(cos(φ)x0+sin(φ)ϵ)=cos(φ)ϵsin(φ)x0\begin{split} v &=\Big(\cos(\varphi)x_0+\sin(\varphi)\epsilon\Big)'\\ &=\cos(\varphi)\epsilon-\sin(\varphi)x_0\\ \end{split}

将三角函数换回原来的样子,可以得到 vv 的表达式。而这就是模型的预测目标:

v=αtϵ1αtx0v=\sqrt{\overline{\alpha}_t}\epsilon-\sqrt{1-\overline{\alpha}_t}x_0

可见,v-prediction 相当于加权聚合了 x-prediction 与 ε-prediction

好,现在还需要获得 x^0(αt,xt,v)\hat{x}_0(\sqrt{\alpha_t},x_t,v)。联立 xt=αtx0+1αtϵx_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\epsilonv=αtϵ1αtx0v=\sqrt{\overline{\alpha}_t}\epsilon-\sqrt{1-\overline{\alpha}_t}x_0 可得:

x^0=αtxt1αtv\hat{x}_0=\sqrt{\overline{\alpha}_t}x_t-\sqrt{1-\overline{\alpha}_t}v

有了 xtx_tx^0\hat{x}_0,就可以计算以下式子来获得噪声估计 ϵ^\hat{\epsilon}

ϵ^=xtαtx^01αt\hat{\epsilon}=\frac{x_t-\sqrt{\overline{\alpha}_t}\hat{x}_0}{\sqrt{1-\overline{\alpha}_t}}

根据 DDPM 原论文所述,已知 xtx_tx0x_0 就可以这样计算 xt1x_{t-1} 分布的均值 μ\mu

μ~t=αt1βt1αtx^0+αt(1αt1)1αtxt\tilde{\mu}_t=\frac{\sqrt{\overline{\alpha}_{t-1}}\beta_t}{1-\overline{\alpha}_t}\hat{x}_0+\frac{\sqrt {\alpha_t}(1-\overline{\alpha}_{t-1})}{1-\overline{\alpha}_t}x_t

而方差 σ2\sigma^2

σ~t2=1αt11αtβt\tilde{\sigma}^2_t=\frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t}\beta_t

现在也可以用重参数化方法获得噪声 ϵ\epsilon

self_condition:利用 x^0\hat{x}_0 进行残差连接

一个叫做 self-conditioning 的 trick,该技巧提出于论文 《Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning (2022)》,据实验证明可以提高扩散模型的性能。

通常模型的推理 需要输入本时间步的图像 x 和时间步 t,然后推理出 x^0\hat{x}_0,进而计算出均值方差和下一步的 x

x_start = model(x, t)

若使用 self-conditioning 技巧,则模型会额外输入上一时间步推理出的 x^0\hat{x}_0

x_start = model(x, t, x_start)

我看 denoising-diffusion-pytorch 库代码实现在输入 x_start 后,会将之与 x 在 channel 维度拼接。之后的 U-Net 推理不变。

换个角度想,这就是一种另类的残差连接。

offset noise:让模型学会打破均值的能力

本节参考:

Diffusion 的正向加噪流程只能 “几乎” 将 x0x_0 变为纯高斯噪声,但模型推理是从纯高斯噪声反推 x0^\hat{x_0} 的。如此,模型从纯高斯噪声推理会失去某些信息,尤其是加噪难以消除的低频信息。

每一步采样所取的高斯噪声几乎确定了模型推理过程的低频均值。更别说像是 Stable Diffusion 在 512x512 图像分辨率时需要面对的 3x64x64=12288 维的随机噪声,低频信息的频率下限显著降低,从而需要相当大的采样步数,才能改变整体的均值。体现到实际图像生成任务中的话,就是模型很容易就生成平均亮度不高不低的图像

为了让模型能够灵活地生成亮图或暗图,我们需要让模型能够学会如何主动修改这个全局均值。解决方法其实挺简单,在训练过程中扰动这个全局均值、强制模型学习如何对抗它。

具体来说,通常的取高斯噪声代码为:

noise = torch.randn_like(latents)

现在在 channel 维度添加一个小的随机扰动,扰动一般乘上 0.1:

noise = torch.randn_like(latents) + 0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1)

这就完成了 offset noise 技巧。

immiscible diffusion:施加的噪声与图像充分混合

论文《Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment (2024)》提出,同样是高斯噪声,也有高低之分。以与 x0x_0 的欧氏距离为尺度,有的噪声比较贴近 x0x_0,有的则相距较远。

所以有了这样的训练思路:正向过程若向 x0x_0 添加欧氏距离相对较小的噪声,则能增加模型学习分辨的难度,进而改善训练效率。

论文实际实验下来的确如此。在一个 mini-batch 内将各个 x0x_0 所对应的噪声重新排序配对,使得各 x0x_0 分得与自身欧氏距离相对最小的噪声,可以加快训练速度。

网上搜索了一圈,目前这个 trick 的关注度不是很高。不知道在更广泛的使用中效果如何。

我个人认为这个 trick 应该很有效,更何况实现起来并不麻烦,可以用上。

denoising-diffusion-pytorch 库的代码实现流程如下。输入 x0x_0 x_start 和噪声 noise 后:

  1. x_start noise 除 batch 的维度合并。即 [batch, channel, hight, width] 变为 [batch, n]
    1. 这样一来 x_startnoise 可以视为两个有 batch 个向量的对象
  2. 借助式子 dist(i,j)=(x[i,k]y[j,k])2\mathrm{dist}(i, j)=\sqrt{\sum (x[i,k]-y[j,k])^2},计算向量两两之间的欧氏距离 dist
  3. 使用 scipy.optimize.linear_sum_assignment(),获得配对总距离最小的组合 assign
  4. assign 重新排序 noise

Min-SNR-γ:不同时间步上设定不同的 loss 权重

本节参考:

由于有 xt=αtx0+1αtϵx_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\epsilon,可以这样定义 xtx_t 的信噪比:

SNR(t)=αt1αt\mathrm{SNR}(t)=\frac{\overline{\alpha}_t}{1-\overline{\alpha}_t}

有一种叫做 Min-SNR-γ 的 trick,避免模型过分关注 tt 较低、加噪较少时的时间步,可以加快收敛速度。

论文《Efficient Diffusion Training via Min-SNR Weighting Strategy》探究出的结果:

  • 使用 ε-prediction 时,wt=min{SNR(t),γ}/SNR(t)w_t=\min\{\mathrm{SNR}(t),\gamma\} / \mathrm{SNR}(t)
  • 使用 x-prediction 时,wt=min{SNR(t),γ}w_t=\min\{\mathrm{SNR}(t),\gamma\}
  • 使用 v-prediction 时,wt=min{SNR(t),γ}/(SNR(t)+1)w_t=\min\{\mathrm{SNR}(t),\gamma\} / (\mathrm{SNR}(t)+1)

γ\gamma 一般取 5。