参考文章
Zhouyifan : 扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现
苏剑林:[生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼]
DDPM程序实现
本文是根据上述参考文章总结而成,这几篇文章写地非常详细,通俗易懂,强烈建议阅读。
DDPM程序结构
DDPM的整体过程是:首先,我们要随机选取训练图片,随机生成当前要训练的时刻,以及随机生成一个生成的高斯噪声。之后,我们把和输入进神经网络,尝试预测噪声。最后,我们以预测噪声和实际噪声的均方误差为损失函数做梯度下降。
def train(ddpm: DDPM, net, device, ckpt_path):
# n_steps 就是公式里的 T
# net 是某个继承自 torch.nn.Module 的神经网络
n_steps = ddpm.n_steps
dataloader = get_dataloader(batch_size)
net = net.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)
for e in range(n_epochs):
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net.state_dict(), ckpt_path)
我们首先完成对于训练数据、以及噪声的采样。
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)
随后是根据采样得到的t以及噪声eps得到,随后利用神经网络net去预测我们所添加的噪声为eps-theta
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
得到了loss之后则可以对net神经网络进行优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
针对去噪神经网络net常用为Unet结构,其中对t的编码通常为Transformer里面的位置编码,并将对应的编码添加到图像中去。
程序为Unet程序里面的forward函数,可以看到t的编码pe是如何被添加到模型中去。
def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
encoder_outs = []
# encoder
for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
self.downs):
# t ~ > pe
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = encoder(x + pe)
encoder_outs.append(x)
x = down(x)
pe = self.pe_mid(t).reshape(n, -1, 1, 1)
# mid
x = self.mid(x + pe)
# decoder
for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
self.decoders, self.ups,
encoder_outs[::-1]):
pe = pe_linear(t).reshape(n, -1, 1, 1)
...
x = decoder(x + pe)
x = self.conv_out(x)
return x
DDPM理论推导(简化版)
根据上面的程序结构,我们可以很容易地归纳出DDPM中只需要一个Unet来预测我们所添加的噪声eps-theta将其与添加的噪声eps,则loss = nn.MSELoss(eps_theta,eps)。
这个loss和常见的回归任务预测的nn.MSELoss(y_pred,y)是完全一致的。下面是对于这loss的更加详细的推导过程。 对应的推导来自于zhouyifan Blog
我们之前的优化函数是让加噪音过程以及去噪音过程这两个更加接近,其等价于eps以及eps_theta更加接近。而这个并不够准确,我们的最终目标是要最大化,也就是经过了向前加噪声以及向后去噪声的过程之后,被还原得到的可能性最大,这一步是和VAE模型的目标是一致的。
对于最大化,我们将这个目标转化为最小化 。根据VAE模型中的推理,我们将 展开得到
上述公式中,总共为三项,第一项与Unet中的可学习参数θ无关可以被忽略,剩余两项分别为:
- 表示的是最大化每一个去噪声操作和加噪声逆操作的相似度,其中
p代表为去噪声过程,q代表为及噪声过程。 - 代表已知让复原得到的概率更高。
随后将上述两项进行简化得到这两项都正比于 ,这一步详细过程见link
因此我们可以将最大化与让预测噪声与添加噪声更相近等价到一起,从而推导出了loss = nn.MSELoss(eps_theta,eps)可以作为最终的损失函数。