代码层面上学习StableDiffusion3

683 阅读3分钟

总览

久闻 StableDiffusion 的大名。现在到第三代了,应当和最开始的 Diffusion 有很大的差别。这次解读会是艰巨的任务。

本文针对文生图任务,对 stabilityai/stable-diffusion-3-medium-diffusers 模型进行代码层面上的运行机理的探索。

最新的有 FLUX 和 SD3.5 模型,做出了一些优化改进。例如位置编码换为了 RoPE,降采样不是使用 2x2 卷积而是 channel 堆叠。本文针对的是 SD3 模型的结构,可以说是老了一代,这该注意。

本文所使用的库版本:

  • transformers==4.46.2
  • diffusers==0.31.0

本文参考:

结构

Stable Diffusion 3 没有采用扩散模型的思路,而是采用了 Flow Matching。最大的优势是,从原理上即使采样步数极致压缩到只剩几步,仍然能获得较好的结果。

然后看看 Stability AI 的文章 有啥亮点。其实文章就只说了自己用一种新架构 MMDiT,强化文本嵌入编码能力。用了两个 CLIP 语言模型和一个 T5 语言模型,得到 pooled_prompt_embeds 和 prompt_embeds 分别控制全程 context 和动态更新当前步的 context。具体看后文。

流匹配模型

不同于通常的扩散模型学习的是逆向噪声,流匹配模型(Flow Matching Models)学习的是逆向速度。噪声分布变换到图像分布的路线变为了直线,该生成过程更适合少步数生成。

在此就略过公式推导。若只想做代码实现,那么获取 xtx_t 有:

xt=σtϵ+(1σt)x0x_t=\sigma_t\epsilon+(1-\sigma_t)x_0

一步采样:

xt=xt1+(σt1σt)vtx_t=x_{t-1}+(\sigma_{t-1} - \sigma_t)v_t

其中 vtv_t 是模型预测的速度方向,指示粒子在 xt1x_{t-1} 时应当如何变换。而 σt\sigma_t 是范围在 [0,1][0,1] 的指示噪声占比的标量。

采样公式对应着 f(t+dt)=f(t)+dtf(t)f(t+\mathrm{d}t)=f(t)+\mathrm{d}t\cdot f'(t)。所以该采样方法被称为欧拉法。

非常直接。

非均匀训练噪声采样

模型进行一次训练时会随机选择某个 tt 让模型进行一次预测。

通常来说,tt 接近于最大值和最小值时较容易预测,所以我们希望模型更多地学习难以预测的中间时刻 tt。于是就有了中间多两边少的非均匀采样来取样 tt

论文经过实验发现 logit-normal(对数正态分布)的方法效果最好。

使用 SD3 所使用的代码

import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
text_encoder = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_3",
    quantization_config=quantization_config,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    text_encoder_3=text_encoder,
    device_map="balanced",
    torch_dtype=torch.float16,
)

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

images

接下来以 pipe() 为切入点,进入 StableDiffusion3Pipeline.__call__()

StableDiffusion3Pipeline.__call__()

获得 prompt 编码

调用 self.encode_prompt() 用于编码 prompt。获得以下变量:

  • prompt_embeds,[2, 333, 4096]
  • pooled_prompt_embeds,[2, 2048]

全程使用了两个 CLIP 模型和一个 T5 模型。可以不需要 T5 模型的 embedding(会视为零张量),但推荐用上 T5 以获得更好的生成结果。

对 prompt 编码的过程如下:

  1. self._get_clip_prompt_embeds(),获得两种嵌入 pooled_prompt_embedsprompt_embeds
    1. 分词器 self.tokenizertransformers.models.clip.tokenization_clip.CLIPTokenizer
    2. 编码器 self.text_encodertransformers.models.clip.modeling_clip.CLIPTextModelWithProjection
      1. 获得 prompt_embeds。包含了 ['text_embeds', 'last_hidden_state', 'hidden_states'] 三部分
    3. 通过 prompt_embeds.text_embeds 获得编码后的张量 pooled_prompt_embeds
      1. 维度为 [batch, 768]
    4. 通过 prompt_embeds.hidden_states[-2] 获得编码器的中间过程 prompt_embeds
      1. 为什么要取中间结果为 embeds?因为绘图不需要彻底编码的高度具体的嵌入(参考来源:github.com/AUTOMATIC11…
      2. 这个 trick 由参数 clip_skip 控制。值越大,模型接受的嵌入越原始
      3. 维度为 [batch, max_seq_length, 768]([1, 77, 768])
    5. 返回 pooled_prompt_embedsprompt_embeds
  2. self._get_clip_prompt_embeds(),获得两种嵌入 pooled_prompt_2_embedprompt_2_embed
    1. 分词器 self.tokenizer_2transformers.models.clip.tokenization_clip.CLIPTokenizer
    2. 编码器 self.text_encoder_2transformers.models.clip.modeling_clip.CLIPTextModelWithProjection
    3. 这个 CLIP 的 max_seq_length 为 77,dim 为 1280
  3. batch 维度拼接 prompt_embedsprompt_2_embed,获得 clip_prompt_embeds
    1. 拼出来的维度为 [1, 77, 2048]
  4. self._get_t5_prompt_embeds(),获得 t5_prompt_embed
    1. 分词器 self.tokenizer_3transformers.models.t5.tokenization_t5_fast.T5TokenizerFast
    2. 编码器 self.text_encoder_3transformers.models.t5.modeling_t5.T5EncoderModel
    3. 返回编码器结果 prompt_embeds.last_hidden_state。维度为 [batch, max_seq_length, 4096]([1, 256, 4096])
  5. prompt_embeds 使用 pad,用 0 将通道数扩充到与 T5 编码输出 t5_prompt_embed 相同的大小
  6. 获得 prompt_embeds:在 sequence_length 维度拼接 prompt_embedst5_prompt_embed
  7. 获得 pooled_prompt_embeds:在 channel 维度拼接 pooled_prompt_embedspooled_prompt_2_embed
  8. 返回 prompt_embeds pooled_prompt_embeds

若设定了 do_classifier_free_guidance=True,也会对 negative prompt 进行编码。编码过程与 prompt 的编码过程一样,最后会获得 negative_prompt_embedsnegative_prompt_embeds。然后在 batch 维度拼接 prompt 和 negative prompt 的 embedding。

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

这就完成了 prompt 编码,获得了 prompt_embedspooled_prompt_embeds

获得 timesteps

使用 self.schedulerdiffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler)获得采样步骤。

scheduler.set_timesteps(num_inference_steps)  # 输入采样步数,让 scheduler 准备 timesteps
timesteps = scheduler.timesteps  # 取出 timesteps

调用 scheduler.set_timesteps() 后会设定 self.timestamps,指示每一步所使用的采样步。其中涉及到 sigma 的创建操作,指示某一步上的噪声占比,取值范围 [0, 1]。与 timestamps 的关系是 sigmas = timesteps / num_train_timestepstimestamps 会借助 sigma 进行偏移变换。

本来的 sigma 是纯线性变化的:

纯线性 sigma

通过 sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) 偏移变换:

变换后的 sigma

意义是,让采样更多集中在低时间步上。或是说,集中在噪声较少的时候。

进行 Denoising 采样

从高斯分布采样出维度 [1, 16, 128, 128] 的矩阵 latent 作为生成图的起点。

由于使用 classifier free guidance,在 batch 维度克隆并拼接一份 latent 作为 latent_model_input

noise_pred:使用 diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel 推理出噪声。

noise_pred 在 batch 维度被分为 noise_pred_uncond noise_pred_text,然后 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

self.guidance_scale 是用来控制 prompt 参与图像生成过程程度的参数。值 0 ~ 1。

self.schedulerdiffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler),计算 prev_sample = latent + (sigma_next - sigma) * noise_pred,从预测噪声和 latent 获得新的 latent。

就这样,遍历完成所有 timesteps

SD3Transformer2DModel.forward()

  • hidden_states:使用 diffusers.models.embeddings.PatchEmbedhidden_states(图像生成的起点 latent_model_input)施加二维位置编码。维度也会 [2, 16, 128, 128] -> [2, 4096, 1536]
    • Conv2d(16, 1536, kernel_size=2, stride=2)。可见 patch 大小为 2x2
    • 切换维度。[batch, channel, height, width] -> [batch, n, channel]
    • self.cropped_pos_embed(),获得裁剪嵌入
      • 以左上角为坐标原点,计算图像在画布居中时的左上角坐标(预设画布大小为 192x192)
      • self.pos_embed 维度为 [1, 192, 192, 1536]。根据图像坐标取对应位置编码为 spatial_pos_embed
        • 这个二维位置编码 self.pos_embed 是用正余弦的方式生成的。channel 前一半指示 h 轴,后一半指示 y 轴
  • temb:使用 diffusers.models.embeddings.CombinedTimestepTextProjEmbeddings,输入步骤数和 pooled_prompt_embeds,输出综合有步骤数信息和 prompt 信息的嵌入
    • self.time_projdiffusers.models.embeddings.Timesteps)将步骤数处理为正余弦嵌入,维度 256
    • self.timestep_embedderdiffusers.models.embeddings.TimestepEmbedding)让步骤嵌入经过 线性层、SiLU、线性层,并在第二个线性层将维度映射到 256 -> 1536
    • self.text_embedderdiffusers.models.embeddings.PixArtAlphaTextProjection)让 prompt 编码经过 线性层、SiLU、线性层,并在第一个线性层将维度映射到 2048 -> 1536
    • 将刚刚处理得到的步骤嵌入和 prompt 编码相加,返回之
  • encoder_hidden_states:使用线性层,将 prompt_embeds 维度 4096 -> 1536
  • 遍历 self.transformer_blocks(包含若干 diffusers.models.attention.JointTransformerBlock),最终获得目标 hidden_states
    • norm_hidden_states:用 self.norm1diffusers.models.normalization.AdaLayerNormZero),借助 temb 更新 hidden_states 并且生成额外 4 个控制用途的张量 gate_msa shift_mlp scale_mlp gate_mlp
      • temb 经过 SiLU -> 线性层,维度 1536 -> 9216,再切分为 6 个 1536 维度的张量
      • 其中 2 个张量(假设为 aa bb)会经过这样的运算来更新 hidden_states(用 xx 表示):x=LayerNorm(x)(1+a)+bx=\mathrm{LayerNorm}(x)*(1+a)+b
      • 返回跟新后的 hidden_states,和其他 4 个张量 gate_msa shift_mlp scale_mlp gate_mlp
    • norm_encoder_hidden_states:用 self.norm1_contextdiffusers.models.normalization.AdaLayerNormZero),借助 temb 更新 encoder_hidden_states 并且生成额外 4 个控制用途的张量 c_gate_msa c_shift_mlp c_scale_mlp c_gate_mlp
    • 使用 self.attn 对刚刚处理得到的 norm_hidden_states norm_encoder_hidden_states 进行自注意力。两个张量转换为 qkv 后在 sequence 维度进行拼接以进行自注意力(所以不是 cross-attention),获得 attn_output context_attn_output
    • 使用输入到本 forward 的初始 hidden_states 进行残差,且用上最开始生成的控制张量:hidden_states = hidden_states + attn_output * gate_msa
    • 又开始残差 hidden_states
      • hidden_states 应用 LayerNorm,然后使用上最开始生成的控制张量:norm_hidden_states = LN(hidden_states) * (1 + scale_mlp) + shift_mlp
      • ff_output:对 norm_hidden_states 经过 FeedForward:Linear -> GeLU -> Linear,维度 1536 -> 6144 -> 1536
      • hidden_states = hidden_states + gate_mlp * ff_output
    • 使用输入到本 forward 的初始 encoder_hidden_states 进行残差,且用上最开始生成的控制变量:encoder_hidden_states = c_gate_msa * context_attn_output + encoder_hidden_states
    • 又开始残差 encoder_hidden_states
      • norm_encoder_hidden_states = LN(encoder_hidden_states) * (1 + c_scale_mlp) + c_shift_mlp
      • context_ff_output:对 norm_encoder_hidden_states 经过 FeedForward:Linear -> GeLU -> Linear,维度 1536 -> 6144 -> 1536
      • encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
    • 返回 encoder_hidden_states hidden_states
  • temb 更新 hidden_statesdiffusers.models.normalization.AdaLayerNormContinuous
    • emb = Linear(SiLU(temb)),维度加到二倍然后切分为 scale shift
    • hidden_states = hidden_states * (1 + scale) + shift
  • 用 Linear 将 hidden_states 维度 1536 -> 64。bias=True
  • hidden_states 维度 [2, 4096, 64] -> [2, 64, 64, 2, 2, 16] -> [2, 16, 128, 128]

关于两个 prompt embeds,一些总结:

  1. temb 结合了步骤数信息和 pooled_prompt_embeds 信息,在这个 step 内不会被更新
  2. encoder_hidden_states 会在这个 step 中的各个 block 里不断被 temb 和其他处理更新
    1. temb 主要会反复调整 encoder_hidden_states 的均值和方差

额外说明

Classifier-free Guidance:控制生成结果

本节参考:

该以怎样的形式将 conditional 信息注入到扩散模型中,以控制生成结果?

论文《Diffusion Models Beat GANs on Image Synthesis (2021)》首先提出了 Classifier Guidance。大致来说,扩散模型在训练过程中会同时训练一个分类模型。分类模型输入当前时间步的图像 xtx_t 和 conditional 信息 yy、输出控制信息,再将控制信息乘上 scale 加在噪图像预测 x^t1\hat{x}_{t-1} 上。损失函数会多出一项,用于最小化分类模型对 xtx_t 的分类结果 y^\hat{y} 与 conditional 信息所期望的真实分类结果 yy。训练过程中 y^\hat{y} 不一定与 yy 相同,但随着训练推进会越来越准确。

yy 只能有 nn 个取值(nn 是设定的分类数量)。若是超出这个范畴的类别,对不起,模型不能理解。这是 Classifier Guidance 的致命缺点。

有没有办法让模型能接受无穷无尽的、各式各样的 yy,即使模型在训练时没见过?可以,只需要一个外援。Stable Diffusion 请的外援是一个叫做 CLIP 的语言模型。

和死板的分类器不一样,例如某个 CLIP 可以将用户的任意文字提示词转换为长度 77、维度 768 的 Embedding。扩散模型使用这个 Embedding 进行指导生成,形成了 Classifier-free Guidance 效果。依托语言模型的泛化能力,扩散模型能 “想象出” 没见过的画面。

顺便一提 γ\gamma 参数(代码中对应 guidance_scale),用于控制 conditional 信息对生成过程的指导力度。效果大概是下面这个式子,其值越接近 1 则控制力度越大。

result=(1γ)model()+γmodel(y)\mathrm{result}=(1-\gamma)\mathrm{model}(\varnothing) + \gamma\mathrm{model}(y)

adaLN-Zero:将控制信息融入到 Transformer 中

本节参考:

把用于噪声预测的 UNet 换为 Transformer,就有了 DiT。

该如何把控制信息(时间步 t、prompt embedding c)融入到 Transformer 呢?论文《Scalable Diffusion Models with Transformers》提到了多种方法:

  • In-context conditioning:将 tc 视为两个 token,与图像序列一同参与 attention。像是 ViT 一样直接
  • Cross-attention block:tc 通过 cross attention 将信息融入到图像序列中。t c 视为长度为 2 的序列作为 key 和 value,图像序列作为 query
  • adaLN(Adaptive layer norm)block:添加类似于 LayerNorm 层的 adaLN 层调节图像序列,其使用的均值和方差不再是根据输入数据决定,而是由 tc 和 adaLN 层自身的可学习权重确定
  • adaLN-Zero block:实验证明把 adaLN 层的权重初始化为全零有利于训练。另外除了均值 shift 和方差 scale,论文还让 adaLN 层额外决定了残差连接的比例,该值被称为 “gate”

adaLN-Zero block 的具体实现可见以下伪代码。adaLN_modulation 生成了两组 shift、scale 和 gate,分别控制 Transformer 中的 attention 部分和 mlp 部分。

class DiTBlock(nn.Module):
    def __init__():
        ···
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        # zero init
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

    def forward(self, x, tc):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(tc)
        x = x + gate_msa * self.attn(self.norm1(x) * (1 + scale_msa) + shift_msa)
        x = x + gate_mlp * self.mlp(self.norm2(x) * (1 + scale_mlp) + shift_mlp)
        return x
        
        
 def modulate(x, shift, scale):
     return x * (1+scale.unsqueeze(1)) + shift.unsqueeze(1)

根据论文的实验结果,adaLN 不仅成功将控制信息添加到了 Transformer,效果还更好,且运算量极低。