总览
模仿 Stable Diffusion 3 实现的 diffusion 模型,使用 MNIST 数据集进行训练。
本项目主要用于自己练习,熟悉扩散模型和流匹配模型的原理,以及模型训练过程。
- 使用 DiT 模型,包含 3 个 Transformer 层
- embedding 总维度 256,注意力头 4 个
训练 300 个 epoch。
网络结构
主要使用了 Stable Diffusion 3 那套 DiT 的网络结构,简单来说是将用于图像生成的 UNet 换为了 Transformer 网络。
其他值得一提的细节:
- 位置编码使用了类似 FLUX 的二维 RoPE 编码
- 时间步采用正余弦编码,类别编码则使用
nn.Embedding生成 - 用 AdaLayerNorm 层以缩放和偏移的形式将时间步信息施加到 Transformer 中
- 将类别嵌入视为单个 token 与图像序列一同参与 Transformer 处理,以此施加类别信息
训练时使用以下式子获得 时刻的加噪图像 :
其中 为原始图像, 为对应时间步 的噪声占比, 为高斯噪声。
损失函数:
其中 是从 获得 所使用的高斯噪声, 是模型预测结果。
一步采样:
训练经验
- loss 大小不能直接表明学习效果。低 loss 也可能会出现糟糕的生成结果
- 过大的学习率会导致训练不稳定。实践下来将学习率设为 2e-4 已经是极限了
- 可以将 embedding 总维度降到 128。模型仍然能收敛,就是效果会打折扣
- 训练使用 offset noise 技巧效果不明显,训练使用 immiscible diffusion 技巧效果倒退。可能是不太适合 MNIST
- 将训练图像的方差归一化到 1 会极大提升训练稳定性
参考来源
感谢这些文章和项目。
- github.com/owenliang/m…
- github.com/TongTong313…
- 周弈帆,“Stable Diffusion 3 论文及源码概览”,zhouyifan.net/2024/07/14/…
- 周弈帆,“Stable Diffusion 3「精神续作」FLUX.1 源码深度前瞻解读”,zhouyifan.net/2024/09/03/…
- 来自 Stability AI 团队的 Stable Diffusion,以及来自 Black Forest Labs 团队的 FLUX.1
- Huggingface 的 diffusers 库:github.com/huggingface…
- “Scaling Rectified Flow Transformers for High-Resolution Image Synthesis”,arxiv.org/pdf/2112.10…