从零实现DDPM-(MNIST版)

136 阅读1分钟

理论推导在这里

本文将会分为两个部分介绍如何使用DDPM生成MNIST手写数字

  1. 训练过程
  2. 采样过程

1. 训练过程

1.1 前向扩散

首先回顾一下什么前向扩散,前向扩散是由目标分布转换为已知分布(高斯分布)的过程,用公式描述就是

xt=αtˉx0+1αtˉϵ(1)x_t = \sqrt{\bar{\alpha_t}}x_0+\sqrt{1-\bar{\alpha_t}}\epsilon \tag{1}

其中x0x_0是目标分布,xtx_t是高斯分布,αt=1βt\alpha_t=1-\beta_tϵN(0,1)\epsilon \sim N(0,1)

βt\beta_t是随tt增大而增大的数,代表着xtx_t中高斯噪声所占比重或者说原始图像被破坏的程度。αˉt=i=0tαt\bar{\alpha}_t=\prod \limits_{i=0}^t\alpha_t

αt\alpha_tβt\beta_t代码如下

beta_schedule = torch.linspace(1e-4, 0.02, T + 1, device=device)
alpha_t_schedule = 1 - beta_schedule
sqrt_minus_bar_alpha_t_schedule = torch.sqrt(1 - bar_alpha_t_schedule)

前向扩散过程代码如下

noise = torch.randn_like(imgs, device=self.device)
noise_imgs = self.sqrt_bar_alpha_t_schedule[t].view((batch_size, 1, 1 ,1)) * imgs \
    + self.sqrt_minus_bar_alpha_t_schedule[t].view((batch_size, 1, 1, 1)) * noise

1.2 损失函数

损失函数的公式为

L=ϵϵθ(αˉtx0+1αˉt,t)2(2)L=||\epsilon-\epsilon_\theta(\sqrt{\bar{\alpha}}_tx_0+\sqrt{1-\bar{\alpha}_t},t)||^2 \tag{2}

ϵθ\epsilon_\theta是一个带有self-attention的UNet,用于估计当前时间步tt时图像包含的高斯噪声。

代码为

pred_noise = self.eps_model(noise_imgs, t.unsqueeze(1))
# calculate of Loss simple ||noise - pred_noise||^2, which is MSELoss
self.criterion(pred_noise, noise)

1.3 Time embedding

跟常规UNet不同的是,DDPM的UNet需要两个输入,一个是带有噪声的图像xtx_t,一个是当前时间步tt。那么一个整数怎么和一个图像相结合呢?

这里用到了位置编码(position embedding)技术。简单来说就是将一个整数通过某种算法变成向量,然后将向量用加法的形式加到UNet的中间层输出里。

有两种方法将tt转换成向量,一种是通过学习的方式,比如一个MLP,将数字直接映射为一个向量。第二种是通过编码的方式将数字映射为一个向量。下面介绍第二种方法。

对于t的编码为

t(PE,2i)=sin(t100002id)t(PE,2i+1)=cos(t100002id)(3)t_{(PE,2i)} = sin(\frac{t}{10000^{\frac{2i}{d}}})\\ t_{(PE,2i+1)} = cos(\frac{t}{10000^{\frac{2i}{d}}})\tag{3}

tPEt_{PE}是编码后的向量,长度为d。对于tPEt_{PE}的偶数位元素,使用正弦编码,奇数位元素,使用余弦编码。

class PositionalEmbedding(nn.Module):
    def __init__(self, T: int, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim
        position = torch.arange(T).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, output_dim, 2) * (-math.log(10000.0) / output_dim))
        pe = torch.zeros(T, output_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor):
        return self.pe[x].reshape(x.shape[0], self.output_dim)

代码中为了避免数字太大溢出,使用了对数变换。

大致实现如下,中间省略了一些无关的代码。

self.time_embedding = nn.Sequential(
    PositionalEmbedding(T=T, output_dim=hid_size),
    nn.Linear(hid_size, time_emb_dim),
    nn.ReLU(),
    nn.Linear(time_emb_dim, time_emb_dim)
)
···
h = self.conv_1(x)
t = self.time_emb(t)
return self.conv_3(x) + self.conv_2(h + t)

2. 反向生成

x0=xt1αtˉϵαˉt(4) x_0 = \frac{x_t-\sqrt{1-\bar{\alpha_t}}\epsilon}{\sqrt{\bar{\alpha}_t}} \tag{4}
p(xt1xt,x0)N(αt(1αˉt1)1αtˉxt+αˉt1βt1αtˉx0,1αˉt11αˉtβt)(5)p(x_{t-1}|x_t,x_0) \sim N(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_t}}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1\bar{\alpha_t}}x_0,\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t) \tag{5}

代码实现如下

def sample(self, n_samples, size):
    self.eval()
    with torch.no_grad():
        # get normal noise
        x_t = torch.randn(n_samples, *size, device=self.device)
        # calculate x_(t-1) on every iteration
        for t in range(self.T, 0, -1):
            t_tensor = torch.tensor([t], device=self.device).repeat(x_t.shape[0], 1)
            # get predicted noise from model
            pred_noise = self.eps_model(x_t, t_tensor)

            # get some noise to calculate x_(t-1) as in formula (How to get a Noise)
            # for t = 0, noise should be 0
            z = torch.randn_like(x_t, device=self.device) if t > 0 else 0

            # Formula from How to get sample
            # x_(t-1) = 1 / sqrt(alpha_t) * (x_t - pred_noise * (1 - alpha_t) / sqrt(1 - alpha_t_bar)) + beta_t * eps
            x_t = 1 / torch.sqrt(self.alpha_t_schedule[t]) * \
                (x_t - pred_noise * (1 - self.alpha_t_schedule[t]) / self.sqrt_minus_bar_alpha_t_schedule[t]) + \
                torch.sqrt(self.beta_schedule[t]) * z
        return x_t

这里跟公式(5)有一点不一样的是高斯分布的方差取得是βt\sqrt{\beta_t}而不是1αˉt11αˉtβt\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t 。论文有讨论过这两种方差都能得到较好的效果,可能是模型对方差选择不敏感。

3. 实验结果

3.1 Loss 曲线

image.png

感觉损失下降的过快了,有没训练好的感觉。

3.2 生成图像

image.png

大部分还是挺好的,但还是有一些奇怪的形状。

完整代码

ddpm_mnist.ipynb

Reference

stable-diffusion-from-scratch