从零实现一个小小小的扩散模型

2,487 阅读7分钟

安装

调包之前确认你已经安装了相应的库,需要pytorch、matplotlib。

然后再安装diffusers

pip install -q diffusers

数据

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

如果你用GPU的话现在这里输出应该是:

Using device: cuda

这里就用最简单mnist数据集,当然如果你想换别的数据集自行更换就OK。

pytorch传统艺能 用DataLoader加载数据

dataset = torchvision.datasets.MNIST(root="mnist/", 
                                     train=True, 
                                     download=True, 
                                     transform=torchvision.transforms.ToTensor())
                                     
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# pytorch传统艺能 用DataLoader加载数据

我们使用DataLoader()读取数据后,用next(iter(data_iter))来返回批量数据,而不能使用 next(data_iter),原理就在这儿。
使用迭代器来返回批量数据,可在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足的问题。

x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

可以看到输出

Input shape: torch.Size([8, 1, 28, 28])

Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])

image.png

每张图像都是一个28x28像素的灰度手写数字,像素值范围从0到1。我们上边设置一个batch_size大小为8,所以这里一组图出来8个。

注意:因为这batch取的是随机的,所以你每次运行显示的八张图是不一样的,和我结果不一样没关系。

加噪过程

假设你还没有阅读过任何扩散模型论文,现在告诉你,扩散模型的前向过程需要给图片加噪声。现在给你提供一个简单的加噪方式:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

如果amount = 0,则返回输入图像不做任何更改。如果amount = 1,则返回纯噪声。

通过这种方式混合输入和噪声,我们可以保持输出在相同的范围内(0到1)。并且这样比较容易实现。写代码时候要注意Tensor的形状,以免pytorch的广播机制被破坏。

def corrupt(x, amount):
  # 根据输入的amount 对 图像加噪
  noise = torch.rand_like(x)
  amount = amount.view(-1, 1, 1, 1) 
  # 使用.view方法修改形状
  return x*(1-amount) + noise*amount 

让我们看看这个加噪代码的效果如何:

# 显示一下输入图像
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# 为图像添加噪声
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# 画出添加噪声之后的图像
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

torch.linspace(0, 1, x.shape[0]) 就是生成0-1之间的数,生成数量为一个batch,也就是生成一个等差数列作为噪声添加给一个batch的图像。

image.png

模型

我们希望我们的模型可以接受一个28*28像素大小的带噪声的输入,并输出一个相同形状的去噪结果。在常规的扩散模型中使用的是U-net。不了解的可以看这个文章:浅谈语义分割网络U-Net。该模型最初是为医学图像分割任务而发明的,一个U-Net由一个“压缩路径”和一个“扩展路径”组成,数据通过“压缩路径”被压缩,然后通过“扩展路径”恢复到原始尺寸(类似于自动编码器),但还包括skip connection,允许数据在不同层次上传递信息和梯度。

一些U-Net在每个阶段都有复杂的组成模块,但我们今天只是简单实现一个扩散模型,所以也不搞什么复杂的U-net结构 了,我们在这里构建一个最最简单的U-net示例,模型可以接收一个单通道图像,并通过压缩路径上的三个卷积层(图表和代码中的down_layers)和扩展路径上的三个卷积层,在下行和上行层之间使用跳过连接。模型中使用最大池化进行下采样,使用nn.Upsample进行上采样。下图是大致的架构,显示每个层输出的通道数:

image.png

代码里用的激活函数是torch.nn.SiLU()

silu(x)=xsigmoid(x)=x11+ex\operatorname{silu}(x)=x * \operatorname{sigmoid}(x)=x \frac{1}{1+e^{-x}}

在压缩路径中,数据通过三个卷积层(存储在self.down_layers中)和最大池化层进行下采样。在每层之后都应用激活函数(存储在self.act中)。对于前两个压缩层,它们的输出还被存储在h列表中,以便在扩展路径中使用它们进行skip connection。在扩展路径中,数据通过三个卷积层(存储在self.up_layers中)进行上采样,并执行skip connection。在每层之后,也要应用激活函数。注意,在第一个上行层之前没有跳跃连接。

class BasicUNet(nn.Module):
# 简易版的U-net
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([ 
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2), 
        ])
        # 激活函数
        self.act = nn.SiLU() 
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x)) 
            if i < 2: 
              h.append(x) 
              x = self.downscale(x) 
              
        for i, l in enumerate(self.up_layers):
            if i > 0: 
              x = self.upscale(x) 
              x += h.pop() 
            x = self.act(l(x)) 
            
        return x

验证一下输入输出是否保持同维度。

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape

这里我们可以看到输出是:

torch.Size([8, 1, 28, 28])

训练

说一下我们训练阶段的设定。先想一下我们这个模型要做什么:

丢给它带噪声的图片,模型应该生成其去噪结果

所以给定带噪声的noisy_x,模型要努力去恢复x

这里我们使用均方误差计算模型输出和原始图像的差异。

接下来就是训练模型部分的代码:

  1. 拿到一个batch的数据

  2. 破坏数据模拟前向前向加噪

  3. 将破坏后的数据放入模型中

  4. 将模型输出结果和原始图像进行比较并计算loss

  5. 更新模型参数

# Dataloader  加载数据
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 我们要训练多少轮
n_epochs = 10

# 创建网络,将模型丢到GPU上(如果有GPU的话)
net = BasicUNet()
net.to(device)

# 损失函数是用的MSE loss
loss_fn = nn.MSELoss()

# 优化器使用的Adam
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

# 记录损失
losses = []

# 训练循环
for epoch in range(n_epochs):

    for x, y in train_dataloader:

        # 准备好输入数据和加噪数据
        # 把数据放到GPU上(如果你有的话)
        x = x.to(device) 
        
        #设定随机噪声
        noise_amount = torch.rand(x.shape[0]).to(device)  
        
        # 处理x,获得加噪之后的样本noisy_x
        noisy_x = corrupt(x, noise_amount) 

        # 获取模型输出结果
        pred = net(noisy_x)

        # 计算loss
        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?

        # 反向传播更新模型参数
        opt.zero_grad()
        loss.backward()
        opt.step()

        # 存储loss记录
        losses.append(loss.item())

    # 输出每轮训练的loss的平均值
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# 画一下loss
plt.plot(losses)
plt.ylim(0, 0.1);

可以看到输出结果如下:

Finished epoch 0. Average loss for this epoch: 0.025983

Finished epoch 1. Average loss for this epoch: 0.020247

Finished epoch 2. Average loss for this epoch: 0.018660

Finished epoch 3. Average loss for this epoch: 0.017662

Finished epoch 4. Average loss for this epoch: 0.016999

Finished epoch 5. Average loss for this epoch: 0.016730

Finished epoch 6. Average loss for this epoch: 0.016610

Finished epoch 7. Average loss for this epoch: 0.016287

Finished epoch 8. Average loss for this epoch: 0.016084

Finished epoch 9. Average loss for this epoch: 0.015731

image.png

模型训练咋样了?

带兄弟们看一下效果嗷:

# 取一组图像
x, y = next(iter(train_dataloader))
x = x[:8] # Only using the first 8 for easy plotting

# 用我们前边给八张图加噪的那个方法,看看模型对不同程度的噪声的回复情况
amount = torch.linspace(0, 1, x.shape[0]) 
noised_x = corrupt(x, amount)

# 获取模型结构
with torch.no_grad():
  preds = net(noised_x.to(device)).detach().cpu()

# 画出结果来
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

下边我放了是三个运行结果,可以看到倒数第二项开始,恢复结果就不咋地了。但是作为一个简单模型,能获得这样的效果已经狠不戳了,随着我们逐渐优化模型,会获得更好的效果的!

  • 实例1 image.png

  • 实例2 image.png

  • 实例3 image.png


本文正在参加「金石计划」