扩散模型的程序结构以及Loss的简单推理

1,362 阅读3分钟

参考文章
Zhouyifan : 扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现
苏剑林:[生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼]
DDPM程序实现
本文是根据上述参考文章总结而成,这几篇文章写地非常详细,通俗易懂,强烈建议阅读。

DDPM程序结构

DDPM的整体过程是:首先,我们要随机选取训练图片X0X_{0},随机生成当前要训练的时刻tt,以及随机生成一个生成XtX_{t}的高斯噪声。之后,我们把XtX_{t}tt输入进神经网络,尝试预测噪声。最后,我们以预测噪声和实际噪声的均方误差为损失函数做梯度下降。

def train(ddpm: DDPM, net, device, ckpt_path):
    # n_steps 就是公式里的 T
    # net 是某个继承自 torch.nn.Module 的神经网络
    n_steps = ddpm.n_steps
    dataloader = get_dataloader(batch_size)
    net = net.to(device)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), 1e-3)

    for e in range(n_epochs):
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
            eps = torch.randn_like(x).to(device)
            x_t = ddpm.sample_forward(x, t, eps)
            eps_theta = net(x_t, t.reshape(current_batch_size, 1))
            loss = loss_fn(eps_theta, eps)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(net.state_dict(), ckpt_path)

我们首先完成对于训练数据X0X_{0}tt以及噪声的采样。

for x, _ in dataloader:
    current_batch_size = x.shape[0]
    x = x.to(device)
    t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
    eps = torch.randn_like(x).to(device)

随后是根据采样得到的t以及噪声eps得到XtX_{t},随后利用神经网络net去预测我们所添加的噪声为eps-theta

x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)

得到了loss之后则可以对net神经网络进行优化

optimizer.zero_grad() 
loss.backward() 
optimizer.step()

针对去噪神经网络net常用为Unet结构,其中对t的编码通常为Transformer里面的位置编码,并将对应的编码添加到图像中去。
程序为Unet程序里面的forward函数,可以看到t的编码pe是如何被添加到模型中去。

    def forward(self, x, t):
        n = t.shape[0]
        t = self.pe(t)
        encoder_outs = []
        
        # encoder
        for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
                                            self.downs):
            # t ~ > pe 
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            x = encoder(x + pe)
            encoder_outs.append(x)
            x = down(x)
        pe = self.pe_mid(t).reshape(n, -1, 1, 1)
        
        # mid
        x = self.mid(x + pe)
        
        # decoder
        for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
                                                       self.decoders, self.ups,
                                                       encoder_outs[::-1]):
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            
            ...
            
            x = decoder(x + pe)
        x = self.conv_out(x)
        return x

DDPM理论推导(简化版)

根据上面的程序结构,我们可以很容易地归纳出DDPM中只需要一个Unet来预测我们所添加的噪声eps-theta将其与添加的噪声eps,则loss = nn.MSELoss(eps_theta,eps)
这个loss和常见的回归任务预测的nn.MSELoss(y_pred,y)是完全一致的。下面是对于这loss的更加详细的推导过程。 对应的推导来自于zhouyifan Blog

我们之前的优化函数是让加噪音过程以及去噪音过程这两个更加接近,其等价于eps以及eps_theta更加接近。而这个并不够准确,我们的最终目标是要最大化Pθ(X0)P_{θ}(X_{0}),也就是X0X_{0}经过了向前加噪声以及向后去噪声的过程之后,被还原得到X0X_{0}的可能性最大,这一步是和VAE模型的目标是一致的。

对于最大化Pθ(X0)P_{θ}(X_{0}),我们将这个目标转化为最小化 log(Pθ(X0))-log(P_{θ}(X_{0}))。根据VAE模型中的推理,我们将 log(Pθ(X0))-log(P_{θ}(X_{0}))展开得到

LVLB=E[DKL(q(xTx0)pθ(xT))+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))logpθ(x0x1)]L_{VLB}=\mathbb{E}\left[D_{KL}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) | p_{\theta}\left(\mathbf{x}_{T}\right)\right)+\sum_{t=2}^{T} D_{KL}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) | p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right]

上述公式中,总共为三项,第一项与Unet中的可学习参数θ无关可以被忽略,剩余两项分别为:

  • DKL(q(xt1xt,x0)pθ(xt1xt))D_{KL}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)|| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right) 表示的是最大化每一个去噪声操作和加噪声逆操作的相似度,其中p代表为去噪声过程,q代表为及噪声过程。
  • logP0(x0x1)-logP_{0}(x_{0}|x_{1})代表已知x1x_{1}让复原得到x0x_{0}的概率更高。
    随后将上述两项进行简化得到这两项都正比于 ϵtϵ(xt,t)2\left\Vert\epsilon_{t} - \epsilon({x_{t,t}})\right\Vert^2,这一步详细过程见link
    因此我们可以将最大化Pθ(X0)P_{θ}(X_{0})与让预测噪声与添加噪声更相近等价到一起,从而推导出了loss = nn.MSELoss(eps_theta,eps)可以作为最终的损失函数。