总览
久闻 StableDiffusion 的大名。现在到第三代了,应当和最开始的 Diffusion 有很大的差别。这次解读会是艰巨的任务。
本文针对文生图任务,对 stabilityai/stable-diffusion-3-medium-diffusers 模型进行代码层面上的运行机理的探索。
最新的有 FLUX 和 SD3.5 模型,做出了一些优化改进。例如位置编码换为了 RoPE,降采样不是使用 2x2 卷积而是 channel 堆叠。本文针对的是 SD3 模型的结构,可以说是老了一代,这该注意。
本文所使用的库版本:
- transformers==4.46.2
- diffusers==0.31.0
本文参考:
- 周弈帆,“Stable Diffusion 3 论文及源码概览”,zhouyifan.net/2024/07/14/…
- 真神。文章写得详实具体
- arxiv.org/html/2403.0…
结构
Stable Diffusion 3 没有采用扩散模型的思路,而是采用了 Flow Matching。最大的优势是,从原理上即使采样步数极致压缩到只剩几步,仍然能获得较好的结果。
然后看看 Stability AI 的文章 有啥亮点。其实文章就只说了自己用一种新架构 MMDiT,强化文本嵌入编码能力。用了两个 CLIP 语言模型和一个 T5 语言模型,得到 pooled_prompt_embeds 和 prompt_embeds 分别控制全程 context 和动态更新当前步的 context。具体看后文。
流匹配模型
不同于通常的扩散模型学习的是逆向噪声,流匹配模型(Flow Matching Models)学习的是逆向速度。噪声分布变换到图像分布的路线变为了直线,该生成过程更适合少步数生成。
在此就略过公式推导。若只想做代码实现,那么获取 有:
一步采样:
其中 是模型预测的速度方向,指示粒子在 时应当如何变换。而 是范围在 的指示噪声占比的标量。
采样公式对应着 。所以该采样方法被称为欧拉法。
非常直接。
非均匀训练噪声采样
模型进行一次训练时会随机选择某个 让模型进行一次预测。
通常来说, 接近于最大值和最小值时较容易预测,所以我们希望模型更多地学习难以预测的中间时刻 。于是就有了中间多两边少的非均匀采样来取样 。
论文经过实验发现 logit-normal(对数正态分布)的方法效果最好。
使用 SD3 所使用的代码
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
text_encoder = T5EncoderModel.from_pretrained(
model_id,
subfolder="text_encoder_3",
quantization_config=quantization_config,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
model_id,
text_encoder_3=text_encoder,
device_map="balanced",
torch_dtype=torch.float16,
)
image = pipe(
prompt="a photo of a cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
height=1024,
width=1024,
guidance_scale=7.0,
).images[0]
images
接下来以 pipe() 为切入点,进入 StableDiffusion3Pipeline.__call__()。
StableDiffusion3Pipeline.__call__()
获得 prompt 编码
调用 self.encode_prompt() 用于编码 prompt。获得以下变量:
prompt_embeds,[2, 333, 4096]pooled_prompt_embeds,[2, 2048]
全程使用了两个 CLIP 模型和一个 T5 模型。可以不需要 T5 模型的 embedding(会视为零张量),但推荐用上 T5 以获得更好的生成结果。
对 prompt 编码的过程如下:
self._get_clip_prompt_embeds(),获得两种嵌入pooled_prompt_embeds和prompt_embeds- 分词器
self.tokenizer(transformers.models.clip.tokenization_clip.CLIPTokenizer) - 编码器
self.text_encoder(transformers.models.clip.modeling_clip.CLIPTextModelWithProjection)- 获得
prompt_embeds。包含了 ['text_embeds', 'last_hidden_state', 'hidden_states'] 三部分
- 获得
- 通过
prompt_embeds.text_embeds获得编码后的张量pooled_prompt_embeds- 维度为 [batch, 768]
- 通过
prompt_embeds.hidden_states[-2]获得编码器的中间过程prompt_embeds- 为什么要取中间结果为 embeds?因为绘图不需要彻底编码的高度具体的嵌入(参考来源:github.com/AUTOMATIC11… )
- 这个 trick 由参数 clip_skip 控制。值越大,模型接受的嵌入越原始
- 维度为 [batch, max_seq_length, 768]([1, 77, 768])
- 返回
pooled_prompt_embeds和prompt_embeds
- 分词器
self._get_clip_prompt_embeds(),获得两种嵌入pooled_prompt_2_embed和prompt_2_embed- 分词器
self.tokenizer_2(transformers.models.clip.tokenization_clip.CLIPTokenizer) - 编码器
self.text_encoder_2(transformers.models.clip.modeling_clip.CLIPTextModelWithProjection) - 这个 CLIP 的 max_seq_length 为 77,dim 为 1280
- 分词器
- batch 维度拼接
prompt_embeds和prompt_2_embed,获得clip_prompt_embeds- 拼出来的维度为 [1, 77, 2048]
self._get_t5_prompt_embeds(),获得t5_prompt_embed- 分词器
self.tokenizer_3(transformers.models.t5.tokenization_t5_fast.T5TokenizerFast) - 编码器
self.text_encoder_3(transformers.models.t5.modeling_t5.T5EncoderModel) - 返回编码器结果
prompt_embeds.last_hidden_state。维度为 [batch, max_seq_length, 4096]([1, 256, 4096])
- 分词器
- 对
prompt_embeds使用 pad,用 0 将通道数扩充到与 T5 编码输出t5_prompt_embed相同的大小 - 获得
prompt_embeds:在 sequence_length 维度拼接prompt_embeds和t5_prompt_embed - 获得
pooled_prompt_embeds:在 channel 维度拼接pooled_prompt_embeds和pooled_prompt_2_embed - 返回
prompt_embedspooled_prompt_embeds
若设定了 do_classifier_free_guidance=True,也会对 negative prompt 进行编码。编码过程与 prompt 的编码过程一样,最后会获得 negative_prompt_embeds 和 negative_prompt_embeds。然后在 batch 维度拼接 prompt 和 negative prompt 的 embedding。
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
这就完成了 prompt 编码,获得了 prompt_embeds 和 pooled_prompt_embeds。
获得 timesteps
使用 self.scheduler(diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler)获得采样步骤。
scheduler.set_timesteps(num_inference_steps) # 输入采样步数,让 scheduler 准备 timesteps
timesteps = scheduler.timesteps # 取出 timesteps
调用 scheduler.set_timesteps() 后会设定 self.timestamps,指示每一步所使用的采样步。其中涉及到 sigma 的创建操作,指示某一步上的噪声占比,取值范围 [0, 1]。与 timestamps 的关系是 sigmas = timesteps / num_train_timesteps。timestamps 会借助 sigma 进行偏移变换。
本来的 sigma 是纯线性变化的:
通过 sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) 偏移变换:
意义是,让采样更多集中在低时间步上。或是说,集中在噪声较少的时候。
进行 Denoising 采样
从高斯分布采样出维度 [1, 16, 128, 128] 的矩阵 latent 作为生成图的起点。
由于使用 classifier free guidance,在 batch 维度克隆并拼接一份 latent 作为 latent_model_input。
noise_pred:使用 diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel 推理出噪声。
noise_pred 在 batch 维度被分为 noise_pred_uncond noise_pred_text,然后 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)。
self.guidance_scale是用来控制 prompt 参与图像生成过程程度的参数。值 0 ~ 1。
用 self.scheduler(diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler),计算 prev_sample = latent + (sigma_next - sigma) * noise_pred,从预测噪声和 latent 获得新的 latent。
就这样,遍历完成所有 timesteps。
SD3Transformer2DModel.forward()
hidden_states:使用diffusers.models.embeddings.PatchEmbed对hidden_states(图像生成的起点latent_model_input)施加二维位置编码。维度也会 [2, 16, 128, 128] -> [2, 4096, 1536]- Conv2d(16, 1536, kernel_size=2, stride=2)。可见 patch 大小为 2x2
- 切换维度。[batch, channel, height, width] -> [batch, n, channel]
self.cropped_pos_embed(),获得裁剪嵌入- 以左上角为坐标原点,计算图像在画布居中时的左上角坐标(预设画布大小为 192x192)
self.pos_embed维度为 [1, 192, 192, 1536]。根据图像坐标取对应位置编码为spatial_pos_embed- 这个二维位置编码
self.pos_embed是用正余弦的方式生成的。channel 前一半指示 h 轴,后一半指示 y 轴
- 这个二维位置编码
temb:使用diffusers.models.embeddings.CombinedTimestepTextProjEmbeddings,输入步骤数和pooled_prompt_embeds,输出综合有步骤数信息和 prompt 信息的嵌入- 用
self.time_proj(diffusers.models.embeddings.Timesteps)将步骤数处理为正余弦嵌入,维度 256 - 用
self.timestep_embedder(diffusers.models.embeddings.TimestepEmbedding)让步骤嵌入经过 线性层、SiLU、线性层,并在第二个线性层将维度映射到 256 -> 1536 - 用
self.text_embedder(diffusers.models.embeddings.PixArtAlphaTextProjection)让 prompt 编码经过 线性层、SiLU、线性层,并在第一个线性层将维度映射到 2048 -> 1536 - 将刚刚处理得到的步骤嵌入和 prompt 编码相加,返回之
- 用
encoder_hidden_states:使用线性层,将prompt_embeds维度 4096 -> 1536- 遍历
self.transformer_blocks(包含若干diffusers.models.attention.JointTransformerBlock),最终获得目标hidden_statesnorm_hidden_states:用self.norm1(diffusers.models.normalization.AdaLayerNormZero),借助temb更新hidden_states并且生成额外 4 个控制用途的张量gate_msashift_mlpscale_mlpgate_mlptemb经过 SiLU -> 线性层,维度 1536 -> 9216,再切分为 6 个 1536 维度的张量- 其中 2 个张量(假设为 )会经过这样的运算来更新
hidden_states(用 表示): - 返回跟新后的
hidden_states,和其他 4 个张量gate_msashift_mlpscale_mlpgate_mlp
norm_encoder_hidden_states:用self.norm1_context(diffusers.models.normalization.AdaLayerNormZero),借助temb更新encoder_hidden_states并且生成额外 4 个控制用途的张量c_gate_msac_shift_mlpc_scale_mlpc_gate_mlp- 使用
self.attn对刚刚处理得到的norm_hidden_statesnorm_encoder_hidden_states进行自注意力。两个张量转换为 qkv 后在 sequence 维度进行拼接以进行自注意力(所以不是 cross-attention),获得attn_outputcontext_attn_output - 使用输入到本 forward 的初始
hidden_states进行残差,且用上最开始生成的控制张量:hidden_states = hidden_states + attn_output * gate_msa - 又开始残差
hidden_states:- 对
hidden_states应用 LayerNorm,然后使用上最开始生成的控制张量:norm_hidden_states = LN(hidden_states) * (1 + scale_mlp) + shift_mlp ff_output:对norm_hidden_states经过 FeedForward:Linear -> GeLU -> Linear,维度 1536 -> 6144 -> 1536hidden_states = hidden_states + gate_mlp * ff_output
- 对
- 使用输入到本 forward 的初始
encoder_hidden_states进行残差,且用上最开始生成的控制变量:encoder_hidden_states = c_gate_msa * context_attn_output + encoder_hidden_states - 又开始残差
encoder_hidden_states:norm_encoder_hidden_states = LN(encoder_hidden_states) * (1 + c_scale_mlp) + c_shift_mlpcontext_ff_output:对norm_encoder_hidden_states经过 FeedForward:Linear -> GeLU -> Linear,维度 1536 -> 6144 -> 1536encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
- 返回
encoder_hidden_stateshidden_states
- 用
temb更新hidden_states(diffusers.models.normalization.AdaLayerNormContinuous)emb = Linear(SiLU(temb)),维度加到二倍然后切分为scaleshifthidden_states = hidden_states * (1 + scale) + shift
- 用 Linear 将
hidden_states维度 1536 -> 64。bias=True hidden_states维度 [2, 4096, 64] -> [2, 64, 64, 2, 2, 16] -> [2, 16, 128, 128]
关于两个 prompt embeds,一些总结:
temb结合了步骤数信息和pooled_prompt_embeds信息,在这个 step 内不会被更新encoder_hidden_states会在这个 step 中的各个 block 里不断被temb和其他处理更新temb主要会反复调整encoder_hidden_states的均值和方差
额外说明
Classifier-free Guidance:控制生成结果
本节参考:
该以怎样的形式将 conditional 信息注入到扩散模型中,以控制生成结果?
论文《Diffusion Models Beat GANs on Image Synthesis (2021)》首先提出了 Classifier Guidance。大致来说,扩散模型在训练过程中会同时训练一个分类模型。分类模型输入当前时间步的图像 和 conditional 信息 、输出控制信息,再将控制信息乘上 scale 加在噪图像预测 上。损失函数会多出一项,用于最小化分类模型对 的分类结果 与 conditional 信息所期望的真实分类结果 。训练过程中 不一定与 相同,但随着训练推进会越来越准确。
只能有 个取值( 是设定的分类数量)。若是超出这个范畴的类别,对不起,模型不能理解。这是 Classifier Guidance 的致命缺点。
有没有办法让模型能接受无穷无尽的、各式各样的 ,即使模型在训练时没见过?可以,只需要一个外援。Stable Diffusion 请的外援是一个叫做 CLIP 的语言模型。
和死板的分类器不一样,例如某个 CLIP 可以将用户的任意文字提示词转换为长度 77、维度 768 的 Embedding。扩散模型使用这个 Embedding 进行指导生成,形成了 Classifier-free Guidance 效果。依托语言模型的泛化能力,扩散模型能 “想象出” 没见过的画面。
顺便一提 参数(代码中对应 guidance_scale),用于控制 conditional 信息对生成过程的指导力度。效果大概是下面这个式子,其值越接近 1 则控制力度越大。
adaLN-Zero:将控制信息融入到 Transformer 中
本节参考:
- 梦游娃娃,“扩散模型(十一)| Transformer-based Diffusion:DiT”,lichtung612.github.io/posts/11-di…
- “Scalable Diffusion Models with Transformers”,arxiv.org/pdf/2212.09…
把用于噪声预测的 UNet 换为 Transformer,就有了 DiT。
该如何把控制信息(时间步 t、prompt embedding c)融入到 Transformer 呢?论文《Scalable Diffusion Models with Transformers》提到了多种方法:
- In-context conditioning:将
t和c视为两个 token,与图像序列一同参与 attention。像是 ViT 一样直接 - Cross-attention block:
t和c通过 cross attention 将信息融入到图像序列中。tc视为长度为 2 的序列作为 key 和 value,图像序列作为 query - adaLN(Adaptive layer norm)block:添加类似于 LayerNorm 层的 adaLN 层调节图像序列,其使用的均值和方差不再是根据输入数据决定,而是由
t、c和 adaLN 层自身的可学习权重确定 - adaLN-Zero block:实验证明把 adaLN 层的权重初始化为全零有利于训练。另外除了均值 shift 和方差 scale,论文还让 adaLN 层额外决定了残差连接的比例,该值被称为 “gate”
adaLN-Zero block 的具体实现可见以下伪代码。adaLN_modulation 生成了两组 shift、scale 和 gate,分别控制 Transformer 中的 attention 部分和 mlp 部分。
class DiTBlock(nn.Module):
def __init__():
···
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# zero init
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
def forward(self, x, tc):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(tc)
x = x + gate_msa * self.attn(self.norm1(x) * (1 + scale_msa) + shift_msa)
x = x + gate_mlp * self.mlp(self.norm2(x) * (1 + scale_mlp) + shift_mlp)
return x
def modulate(x, shift, scale):
return x * (1+scale.unsqueeze(1)) + shift.unsqueeze(1)
根据论文的实验结果,adaLN 不仅成功将控制信息添加到了 Transformer,效果还更好,且运算量极低。