扩散模型(Diffusion Model)是继GAN、VAE后的一种生成式模型,而目前在文生图领域比较流行的工具,如DALL-E2、Imagen、Stable Diffusion等,均是以上述扩散模型为基础,不断进行算法优化、迭代,取得了令人惊艳的效果。
DDPM
扩散模型于2015年在论文《Deep unsupervised learning using nonequilibrium thermodynamics》中被提出,并于2020年在论文《Denoising diffusion probabilistic models》中被改进、用于图片生成。《Denoising Diffusion Probabilistic Models》中提出的扩散模型被称为DDPM。
正向扩散过程
令原始图片样本为,其满足分布。定义前向扩散过程,在步内,每步给样本增加一个小的满足高斯分布的噪声,从而产生个带噪声的样本,整个过程为一个一阶马尔可夫过程,只与有关,可用以下公式表示:
其中,表示给定时,的条件概率,即均值为、方差为的高斯分布,集合用于控制每步的噪声大小。进一步给定时,整个马尔科夫过程的条件概率为各步条件概率的连乘,可用以下公式表示: 正向扩散过程可由图1从右到左的过程表示,其中为原始图片,随着每步增加噪声,图片逐渐变得模糊。
对于上述正向扩散过程,可进一步令,且,则可用以下公式表示:
即是在的基础上,增加一个满足高斯分布的噪声,循环递归,即可进一步推导为在的基础上,增加一个满足高斯分布的噪声。这里使用了高斯分布的一个特性,即两个高斯分布合并后仍是一个高斯分布,例如分布和,合并后的分布为。
反向扩散过程
以上介绍了正向扩散过程,即图1从右到左,对原始图片逐步增加噪声,如果将过程逆向,即图1从左到右,那么就能从满足高斯分布的噪音逐步还原原始图片样本,这就是基于扩散模型生成图片的基本思想,即从到的每一步,在给定时,根据条件概率采样求解,直至最终得到。 而当正向扩散过程每步增加的噪声很小时,反向扩散过程的条件概率也可以认为满足高斯分布,但实际上,我们不能直接求解该条件概率,因为直接求解需要整体数据集合。除直接求解外,另一个方法是训练一个模型近似预估上述条件概率,可用以下公式表示:
从到的每一步,通过模型,输入和,预测的高斯分布的均值和方差,基于预测值,可以从的高斯分布中进行采样,从而得到的一个可能取值,如此循环,直至最终得到的一个可能取值。通过上述反向扩散过程,即可以实现从一个满足高斯分布的随机噪声,生成一张图片。而由于每次预测均是从一个概率密度函数中进行采样,因此,可以保证生成图片的多样性。 更进一步,论文进一步将模型预测的均值和方差转化为预测噪声,并推导出和的关系:
因此,可表示为:
论文将固定为常量,通过模型预测,并使用上述公式的概率密度函数进行采样,从降噪得到。
模型结构
DDPM中预测的模型基于OpenAI于2017年发布的一个U-Net形式的网络结构PixelCNN++。U-Net于2015年在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中发布,起初主要用于医学图像的切割,目前作为常用的去噪结构,广泛应用于扩散模型中。
U-Net的网络结构如图2所示,因整体结构形似字母U而得名,U型左侧是4层编码器层,对图片进行降维,U型右侧是4层解码器层,对图片进行升维。编码器的每一层,先是连续的两个卷积层(卷积核维度为3×3)和ReLU层,再接一个池化层进行下采样,然后输入下一层,卷积层的通道数逐层加倍,例如,第一层的输入是572×572的单通道图片,经过两个卷积核维度为3×3、通道为64、无padding的卷积层,输出张量维度分别为570×570×64、568×568×64,再经过一个维度为2×2的最大池化层后,输出张量维度为284×284×64,如此循环,最后一层输出的张量维度为32×32×512。在编码器层和解码器层之间的中间层,经过两个卷积层(卷积核维度为3×3、通道为1024)和ReLU层,输出的张量维度为28×28×1024。解码器的每一层,和编码器类似,也先是连续的两个卷积层(卷积核维度为3×3)和ReLU层,和编码器不同的是,解码器从下层到上层,通过一个上卷积层(卷积核维度为3×3)进行上采样(长、宽维度加倍,但通道缩小),同时,解码器每层的输入除上一层上采样的输出外,还包括同层编码器输出的裁剪。例如,解码器第一层的输入,包括中间层上采样的输出,张量维度为56×56×512,和同层编码器输出的裁剪,张量维度为56×56×512,合并后的张量维度为56×56×1024,经过两个卷积核维度为3×3、通道为512的卷积层,输出张量维度分别为54×54×512、52×52×512。解码器最后一层的输出,再通过一个1×1的卷积层,将原先的64通道映射为指定的通道,因为原始U-Net用于图像切割,即对图像每个像素做分类,所以有多少个分类,即有多少个最终的通道。
而PixelCNN++于2017年在论文《PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications》中发布,其网络结构如图3所示。图中,矩形区块对应于U-Net中的编码器或解码器层,共3个编码器层、3个解码器层。在每个编码器或解码器中,PixelCNN++在原U-Net两个卷积层的基础上,增加了一个残差连接。DDPM进一步进行网络结构的改进,包括:使用Group Normalization进行归一化;在残差卷积块后增加自注意力层;使用Transformer中Sinusoidal Position Embedding对步数编码成Embedding向量作为模型输入。 DDPM的代码开源,代码地址是:github.com/hojonathanh…,其深度学习框架采用Tensorflow,计算资源采用Google Cloud TPU v3-8。diffusion_tf/models/unet.py中定义了网络结构,核心代码如下所示(增加了部分注释):
with tf.variable_scope(name, reuse=reuse):
# Timestep embedding
# 将步数t编码成Embedding向量
with tf.variable_scope('temb'):
temb = nn.get_timestep_embedding(t, ch)
temb = nn.dense(temb, name='dense0', num_units=ch * 4)
temb = nn.dense(nonlinearity(temb), name='dense1', num_units=ch * 4)
assert temb.shape == [B, ch * 4]
# Downsampling
# 多层编码器层
hs = [nn.conv2d(x, name='conv_in', num_units=ch)]
for i_level in range(num_resolutions):
with tf.variable_scope('down_{}'.format(i_level)):
# Residual blocks for this resolution
# 构造编码器层,残差卷积块+自注意力层,并进行下采样
for i_block in range(num_res_blocks):
h = resnet_block(
hs[-1], name='block_{}'.format(i_block), temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
if h.shape[1] in attn_resolutions:
h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
hs.append(h)
# Downsample
if i_level != num_resolutions - 1:
hs.append(downsample(hs[-1], name='downsample', with_conv=resamp_with_conv))
# Middle
# 中间层,残差卷积块+自注意力层+残差卷积块
with tf.variable_scope('mid'):
h = hs[-1]
h = resnet_block(h, temb=temb, name='block_1', dropout=dropout)
h = attn_block(h, name='attn_1'.format(i_block), temb=temb)
h = resnet_block(h, temb=temb, name='block_2', dropout=dropout)
# Upsampling
# 多层解码器层
for i_level in reversed(range(num_resolutions)):
with tf.variable_scope('up_{}'.format(i_level)):
# Residual blocks for this resolution
# 构造解码器层,残差卷积块+自注意力层,并进行上采样
for i_block in range(num_res_blocks + 1):
h = resnet_block(tf.concat([h, hs.pop()], axis=-1), name='block_{}'.format(i_block),
temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
if h.shape[1] in attn_resolutions:
h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
# Upsample
if i_level != 0:
h = upsample(h, name='upsample', with_conv=resamp_with_conv)
assert not hs
# End
# 最后再经过一个卷积层输出
h = nonlinearity(normalize(h, temb=temb, name='norm_out'))
h = nn.conv2d(h, name='conv_out', num_units=out_ch, init_scale=0.)
assert h.shape == x.shape[:3] + [out_ch]
return h
其中,残差卷积块的代码如下(增加了部分注释):
def resnet_block(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout):
B, H, W, C = x.shape
if out_ch is None:
out_ch = C
with tf.variable_scope(name):
h = x
# 对图片进行归一化和非线性转化
h = nonlinearity(normalize(h, temb=temb, name='norm1'))
# 对图片进行卷积
h = nn.conv2d(h, name='conv1', num_units=out_ch)
# add in timestep embedding
# 对步数t的embedding向量进行非线性转化,并合并至图片
h += nn.dense(nonlinearity(temb), name='temb_proj', num_units=out_ch)[:, None, None, :]
# 对合并图片和步数后的输入再进行归一化和非线性转化,并再进行卷积
h = nonlinearity(normalize(h, temb=temb, name='norm2'))
h = tf.nn.dropout(h, rate=dropout)
h = nn.conv2d(h, name='conv2', num_units=out_ch, init_scale=0.)
# 对两次卷积后的输出和原始输入进行残差连接
if C != out_ch:
if conv_shortcut:
x = nn.conv2d(x, name='conv_shortcut', num_units=out_ch)
else:
x = nn.nin(x, name='nin_shortcut', num_units=out_ch)
assert x.shape == h.shape
print('{}: x={} temb={}'.format(tf.get_default_graph().get_name_scope(), x.shape, temb.shape))
return x + h
自注意力层的代码如下(即Transformer中的缩放点积注意力层):
def attn_block(x, *, name, temb):
B, H, W, C = x.shape
with tf.variable_scope(name):
h = normalize(x, temb=temb, name='norm')
q = nn.nin(h, name='q', num_units=C)
k = nn.nin(h, name='k', num_units=C)
v = nn.nin(h, name='v', num_units=C)
w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5))
w = tf.reshape(w, [B, H, W, H * W])
w = tf.nn.softmax(w, -1)
w = tf.reshape(w, [B, H, W, H, W])
h = tf.einsum('bhwHW,bHWc->bhwc', w, v)
h = nn.nin(h, name='proj_out', num_units=C, init_scale=0.)
assert h.shape == x.shape
print(tf.get_default_graph().get_name_scope(), x.shape)
return x + h
训练采样
训练
通过模型预测误差,损失函数采用均方误差(MSE,Mean-Squared Error):
模型训练的目标即最小化上述损失函数,即使模型预测出的噪声和真实噪声尽可能接近。
训练算法如图4所示,采用梯度下降算法,循环下述过程直至模型收敛:
- 对于样本,从中随机采样步数;
- 从高斯分布中采样真实噪声;
- 根据样本和真实噪声,使用前面推导出的公式计算第步正向扩散后带噪声的图片;
- 根据带噪声的图片和步数,使用模型预测噪声,即;
- 根据真实噪声和预测噪声计算损失函数的梯度,即;
- 根据梯度和学习率超参更新模型参数。
采样
采样算法如图5所示,过程如下:
- 从高斯分布中采样完全噪声图片;
- 循环步,步数从到1,直至计算得到,生成最终的图片,对于其中的某一步:
- 根据带噪声的图片和步数,使用模型预测噪声;
- 前面已推导出概率密度函数满足高斯分布,使用公式由噪声计算均值;
- 对于上述概率密度函数,指定方差为常量,根据该分布进行采样,从得到,即。