用 MNIST 训练 Diffusion 模型的代码

378 阅读2分钟

总览

模仿 Stable Diffusion 3 实现的 diffusion 模型,使用 MNIST 数据集进行训练。

代码地址:github.com/HiDolen/dif…

本项目主要用于自己练习,熟悉扩散模型和流匹配模型的原理,以及模型训练过程。

  • 使用 DiT 模型,包含 3 个 Transformer 层
  • embedding 总维度 256,注意力头 4 个

训练 300 个 epoch。

训练过程与结果

网络结构

主要使用了 Stable Diffusion 3 那套 DiT 的网络结构,简单来说是将用于图像生成的 UNet 换为了 Transformer 网络。

其他值得一提的细节:

  • 位置编码使用了类似 FLUX 的二维 RoPE 编码
  • 时间步采用正余弦编码,类别编码则使用 nn.Embedding 生成
  • 用 AdaLayerNorm 层以缩放和偏移的形式将时间步信息施加到 Transformer 中
  • 将类别嵌入视为单个 token 与图像序列一同参与 Transformer 处理,以此施加类别信息

训练时使用以下式子获得 tt 时刻的加噪图像 xtx_t

xt=σtϵ+(1σt)x0x_t=\sigma_t\epsilon+(1-\sigma_t)x_0

其中 x0x_0 为原始图像,σt\sigma_t 为对应时间步 tt 的噪声占比,ϵ\epsilon 为高斯噪声。

损失函数:

loss=MSE(xtϵ,vt)\mathrm{loss}=\mathrm{MSE}(x_t - \epsilon, v_t)

其中 ϵ\epsilon 是从 x0x_0 获得 xtx_t 所使用的高斯噪声,vtv_t 是模型预测结果。

一步采样:

xt=xt1+(σt1σt)vtx_t=x_{t-1}+(\sigma_{t-1} - \sigma_t)v_t

训练经验

  • loss 大小不能直接表明学习效果。低 loss 也可能会出现糟糕的生成结果
  • 过大的学习率会导致训练不稳定。实践下来将学习率设为 2e-4 已经是极限了
  • 可以将 embedding 总维度降到 128。模型仍然能收敛,就是效果会打折扣
  • 训练使用 offset noise 技巧效果不明显,训练使用 immiscible diffusion 技巧效果倒退。可能是不太适合 MNIST
  • 将训练图像的方差归一化到 1 会极大提升训练稳定性

参考来源

感谢这些文章和项目。