1. 背景
1.1 生成模型
生成模型(Generative Models)是深度学习领域的重要分支,其核心目标是学习数据分布 ( p(x) ),并具备以下能力:
- 数据生成:从隐空间采样生成新样本
- 概率建模:估计观测数据的似然值
- 隐表征学习:发现数据背后的潜在结构
1.2 扩散模型的优势
- 训练稳定性强
- 生成图像质量高
- 数学理论背景扎实
2. 核心概念
扩散模型会将一个高斯分布的样本一步步转化为目标分布的样本(比如一只猫的图像),从高斯分布样本到目标分布样本的每一步都是通过深度学习模型来实现。
理解DDPM主要需要理解两个概念即前向扩散和反向去噪。
而前向扩散过程是利用预先定义的加噪规则,不断将某个分布的样本转化为纯高斯分布的样本。反向去噪是指将高斯噪声样本一步步转化为目标分布样本,其中每一步的去噪操作通过深度学习模型实现。
2.1 前向扩散
前向过程是指将是指对一个目标分布的样本分多步添加高斯噪声,最终使目标分布样本变为高斯分布样本的过程。
下面这张图就表示了通过不断对一张猫的图像添加噪声,最终使一张猫的图片变为纯高斯噪声的图像。
在DDPM中噪声添加方式如下:
- 假设要分T步进行加噪,设置超参数,t代表当前是第几步加噪。随t增大逐渐增大。
- 设
其中是目标分布的样本(上图中的猫),是符合标准高斯分布的样本, 是加过噪声的图像。
当 时,,, 图像没有被添加噪声。
当 时,,, 图像变为完全的高斯噪声。
2.2 反向去噪
相对应的,反向去噪过程就是将一个完全的高斯噪声图像,逐步去噪,变为目标分布图像的过程,即图像生成过程。
这一步通过神经网络来估计出图像带有的噪声,并去掉
就是我们需要训练出来的模型,它接受两个输入,一是带有噪声的图像,二是当前时间步,它需要依据这两项来估计出当前图像带有的高斯噪声。
前向扩散和反向去噪公式的由来均有严格的数学证明,这里由于篇幅受限,就不详细介绍了。
3. 实验流程
3.1 训练流程
- 选择一个目标分布的样本(即训练集里的样本) ,并确定T,按照DDPM论文,T=1000。
- 选择一个深度神经网络来估计噪声。
- 随机生成一个标准高斯分布的样本 。
- 随机生成时间步, 。
- 利用前向扩散计算出带噪声的图像
- 使用模型估计出里的噪声
- 通过最小化损失函数更新网络参数
3.2 生成(采样)流程
- 随机生成一个标准高斯分布的样本 。
- 利用反向去噪过程逐步将里的噪声去掉,最终得到目标分布样本。
4. 代码实现
代码主要包括三部分 1.数据预处理 2.模型搭建,3.训练与测试。
这里仅仅介绍关键代码,全部代码可以在github查看。
4.1 数据预处理
将数据归一化到[-1,1]
transform_to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((.5,.5,.5), (.5,.5,.5))
])
train_dataset = torchvision.datasets.CIFAR10(root="/home/debugwang/data/CIFAR10/", download=False, transform=transform_to_tensor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dataset = torchvision.datasets.CIFAR10(root="/home/debugwang/data/CIFAR10/", download=False, transform=transform_to_tensor, train=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
4.2 模型搭建
使用带有多头自注意力机制的UNet。
注意跟常规UNet不同的是,DDPM的UNet需要两个输入,一个是带有噪声的图像,一个是当前时间步。那么一个整数怎么和一个图像相结合呢?
这里用到了位置编码(position embedding)技术,来自transformer。简单来说就是将一个整数通过某种算法变成向量,然后将向量用加法的形式加到UNet的中间层输出里。
对于t的编码为
是编码后的向量,长度为d。对于的偶数位元素,使用正弦编码,奇数位元素,使用余弦编码。
class PositionalEmbedding(nn.Module):
def __init__(self, T: int, output_dim: int) -> None:
super().__init__()
self.output_dim = output_dim
position = torch.arange(T).unsqueeze(1)
div_term = torch.exp(torch.arange(0, output_dim, 2) * (-math.log(10000.0) / output_dim))
pe = torch.zeros(T, output_dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor):
return self.pe[x].reshape(x.shape[0], self.output_dim)
代码中为了避免数字太大溢出,使用了对数变换。
4.3 训练与采样
训练与采样按照原论文的流程进行。
4.3.1 训练
def forward(self, imgs):
# random choose some time steps
t = torch.randint(low=1, high=self.T+1, size=(imgs.shape[0],), device=self.device)
# get random noise to add it to the images
noise = torch.randn_like(imgs, device=self.device)
# get noise image as: sqrt(alpha_t_bar) * x0 + noise * sqrt(1 - alpha_t_bar)
batch_size, channels, width, height = imgs.shape
noise_imgs = self.sqrt_bar_alpha_t_schedule[t].view((batch_size, 1, 1 ,1)) * imgs \
+ self.sqrt_minus_bar_alpha_t_schedule[t].view((batch_size, 1, 1, 1)) * noise
# get predicted noise from our model
pred_noise = self.eps_model(noise_imgs, t.unsqueeze(1))
# calculate of Loss simple ||noise - pred_noise||^2, which is MSELoss
return self.criterion(pred_noise, noise)
4.3.2 采样
def sample(self, n_samples, size):
self.eval()
with torch.no_grad():
# get normal noise
x_t = torch.randn(n_samples, *size, device=self.device)
# calculate x_(t-1) on every iteration
for t in range(self.T, 0, -1):
t_tensor = torch.tensor([t], device=self.device).repeat(x_t.shape[0], 1)
# get predicted noise from model
pred_noise = self.eps_model(x_t, t_tensor)
# get some noise to calculate x_(t-1) as in formula (How to get a Noise)
# for t = 0, noise should be 0
z = torch.randn_like(x_t, device=self.device) if t > 0 else 0
# Formula from How to get sample
# x_(t-1) = 1 / sqrt(alpha_t) * (x_t - pred_noise * (1 - alpha_t) / sqrt(1 - alpha_t_bar)) + beta_t * eps
x_t = 1 / torch.sqrt(self.alpha_t_schedule[t]) * \
(x_t - pred_noise * (1 - self.alpha_t_schedule[t]) / self.sqrt_minus_bar_alpha_t_schedule[t]) + \
torch.sqrt(self.beta_schedule[t]) * z
return x_t
5 实验结果
5.1 生成的图像
下面是DDPM原论文生成的图像
有一些还可以,有一些就不太行。主要还是算力不太够,原论文训练了800k steps,我用的是4070TiS,训练了150k steps 就用了一天多,等有钱了一定换个好显卡。
Reference
本文代码主要参考 stable-diffusion-from-scratch
后续计划
后续会继续实现带条件的DDPM。