StableDiffusion
引言
Stable Diffusion作为目前最流行的开源生图模型之一,良好的生成效果和较低的硬件要求让他在开源社区取得了广泛应用。本文将深入解析Stable Diffusion的模型架构,帮助更好地理解其工作原理。
生图流程
1、通过TextEncoder得到text embedding注入context
2、随机噪声得到隐变量latent
3、编码时间步到time embedding
4、迭代经过unet,条件判断符合后输出到VAE
5、VAE解码输出普通图像
核心组件
1、VAE(变分自编码器)
VAE在StableDiffusion中的主要定位是重建和压缩图像,由编码、解码两个模块组成。
- 编码器,将原始图像 nchw 编码至潜空间 (n,4,64,64)
- 解码器,将潜空间图像解码至普通图像
2、Text Encoder(文本编码器)
TextEncoder负责将提示文本Prompt编码至文本潜空间(图1使用Clip Text Encoder),TextEmbedding作为单次U-Net Context,嵌入Transformer模块的Query,关联隐变量,指导图像扩散。
3、U-Net
U-Net作为StableDiffusion的核心组件,负责在潜空间对图像降噪推理。
U-Net架构
U-Net按输入输出,可以拆分为InputBlock,OuputBlock两部分,他们主要由ResBlock、AttentionBlock两部分构成。
我们查看U-Net源码,位置在 github.com:CompVis/stable-diffusion.git,ldm/modules/diffusionmodules/openaimodel.py
1、首先查看U-Net的InputBlock初始化,可以很直观地看到是一个循环迭代的架构。第二层循环内,首先插入ResBlock,然后根据状态插入AttentionBlock或者SpatialTransformer
2、其次,可以看到每一个input_block都会有一个时间步嵌入参数,作为ResBlock的输入,用于区分隐变量的层次,指导扩散。
3、查看U-Net forward过程
1)输入x是隐变量latent,context是text embedding,y是类别条件(可选项)
2)求时间步嵌入矩阵
3)判断是否打开类别条件,如果开,则求类别emb,直接加到时间步embedding
4)定义module返回值数组hs,存入InputBlock的执行结果
5)执行middle_block,并返回,middle_block包含了两个ResBlock,一个AttentionBlock
6)在output_block中,通过hs连接了input_block的中间输出,防止output_block过程中梯度消失
4、查看ResBlock
1)self.in_layers,顾名思义,是输入序列。先过normalization,再过激活,再走一个conv。
2)根据入参updown,选择上采样或者下采样,InputBlock属于下采样阶段,OutputBlock属于上采样阶段。
3)self.emb_layers,emb输入序列,先走一个激活,再做线性投影。
4)self.out_layers,输出序列,这里的zero_module是把conv_nd中的参数detach后置零,影响首次前向输出,不影响反向传播。
5、观察ResBlock forward过程
1)这里有两个输入,latent和时间步embedding
2)可以很清晰看到,emb_out是怎样被嵌入到hidden_state
6、查看AttentionBlock
1)前置归一化
2)qkv投影用conv1*1代替了gemm
3)注意力计算self.attention
8、查看SpatialTransformer
1)这里的in_channels其实是attention的输入,n_heads是头的数量,d_head是单个头的维度
2)前置了norm,似乎最新的attention都有了前置norm
3)kernel_size为1的conv,功能就是gemm,做了一个输入投影,使输入符合attention的维度
9、查看BasicTransformerBlock(这是SpatialTransformer的核心模块)
1)观察前向过程,先做了一次自注意力,用残差连接
2)再用自注意力的输出,以latent作为query,text_embedding作为key、value,做了一次cross_attention
3)最后走一个常规FFN后输出
总结
-
主要流程:
1)通过 Text Encoder (如CLIP、ChatGLM) 将文本转换为embedding注入context
2)从随机噪声生成初始latent隐变量
3)将时间步编码为time embedding
4)通过U-Net进行迭代降噪
5)最后用VAE将latent转换为最终图像 -
核心组件:
1)VAE:负责图像压缩和重建,包含编码器(图像→latent)和解码器(latent→图像)
2)Text Encoder:将prompt编码为文本embeddings
3)U-Net:核心降噪网络,包含:- Input/Output Block:由 ResBlock 和 Attention 模块组成
- 时间步embedding在ResBlock中注入
- SpatialTransformer中的BasicTransformerBlock通过cross-attention机制将文本特征与图像特征关联
-
特点:
1)在潜空间而不是像素空间进行扩散(基于性能,原图太大)
2)使用 cross-attention 机制实现文本引导
3)采用 U-Net 结构配合残差连接