本文将会分为两个部分介绍如何使用DDPM生成MNIST手写数字
- 训练过程
- 采样过程
1. 训练过程
1.1 前向扩散
首先回顾一下什么前向扩散,前向扩散是由目标分布转换为已知分布(高斯分布)的过程,用公式描述就是
其中是目标分布,是高斯分布,, 。
是随增大而增大的数,代表着中高斯噪声所占比重或者说原始图像被破坏的程度。
和代码如下
beta_schedule = torch.linspace(1e-4, 0.02, T + 1, device=device)
alpha_t_schedule = 1 - beta_schedule
sqrt_minus_bar_alpha_t_schedule = torch.sqrt(1 - bar_alpha_t_schedule)
前向扩散过程代码如下
noise = torch.randn_like(imgs, device=self.device)
noise_imgs = self.sqrt_bar_alpha_t_schedule[t].view((batch_size, 1, 1 ,1)) * imgs \
+ self.sqrt_minus_bar_alpha_t_schedule[t].view((batch_size, 1, 1, 1)) * noise
1.2 损失函数
损失函数的公式为
是一个带有self-attention的UNet,用于估计当前时间步时图像包含的高斯噪声。
代码为
pred_noise = self.eps_model(noise_imgs, t.unsqueeze(1))
# calculate of Loss simple ||noise - pred_noise||^2, which is MSELoss
self.criterion(pred_noise, noise)
1.3 Time embedding
跟常规UNet不同的是,DDPM的UNet需要两个输入,一个是带有噪声的图像,一个是当前时间步。那么一个整数怎么和一个图像相结合呢?
这里用到了位置编码(position embedding)技术。简单来说就是将一个整数通过某种算法变成向量,然后将向量用加法的形式加到UNet的中间层输出里。
有两种方法将转换成向量,一种是通过学习的方式,比如一个MLP,将数字直接映射为一个向量。第二种是通过编码的方式将数字映射为一个向量。下面介绍第二种方法。
对于t的编码为
是编码后的向量,长度为d。对于的偶数位元素,使用正弦编码,奇数位元素,使用余弦编码。
class PositionalEmbedding(nn.Module):
def __init__(self, T: int, output_dim: int) -> None:
super().__init__()
self.output_dim = output_dim
position = torch.arange(T).unsqueeze(1)
div_term = torch.exp(torch.arange(0, output_dim, 2) * (-math.log(10000.0) / output_dim))
pe = torch.zeros(T, output_dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor):
return self.pe[x].reshape(x.shape[0], self.output_dim)
代码中为了避免数字太大溢出,使用了对数变换。
大致实现如下,中间省略了一些无关的代码。
self.time_embedding = nn.Sequential(
PositionalEmbedding(T=T, output_dim=hid_size),
nn.Linear(hid_size, time_emb_dim),
nn.ReLU(),
nn.Linear(time_emb_dim, time_emb_dim)
)
···
h = self.conv_1(x)
t = self.time_emb(t)
return self.conv_3(x) + self.conv_2(h + t)
2. 反向生成
代码实现如下
def sample(self, n_samples, size):
self.eval()
with torch.no_grad():
# get normal noise
x_t = torch.randn(n_samples, *size, device=self.device)
# calculate x_(t-1) on every iteration
for t in range(self.T, 0, -1):
t_tensor = torch.tensor([t], device=self.device).repeat(x_t.shape[0], 1)
# get predicted noise from model
pred_noise = self.eps_model(x_t, t_tensor)
# get some noise to calculate x_(t-1) as in formula (How to get a Noise)
# for t = 0, noise should be 0
z = torch.randn_like(x_t, device=self.device) if t > 0 else 0
# Formula from How to get sample
# x_(t-1) = 1 / sqrt(alpha_t) * (x_t - pred_noise * (1 - alpha_t) / sqrt(1 - alpha_t_bar)) + beta_t * eps
x_t = 1 / torch.sqrt(self.alpha_t_schedule[t]) * \
(x_t - pred_noise * (1 - self.alpha_t_schedule[t]) / self.sqrt_minus_bar_alpha_t_schedule[t]) + \
torch.sqrt(self.beta_schedule[t]) * z
return x_t
这里跟公式(5)有一点不一样的是高斯分布的方差取得是而不是 。论文有讨论过这两种方差都能得到较好的效果,可能是模型对方差选择不敏感。
3. 实验结果
3.1 Loss 曲线
感觉损失下降的过快了,有没训练好的感觉。
3.2 生成图像
大部分还是挺好的,但还是有一些奇怪的形状。