安装
调包之前确认你已经安装了相应的库,需要pytorch、matplotlib。
然后再安装diffusers
pip install -q diffusers
数据
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
如果你用GPU的话现在这里输出应该是:
Using device: cuda
这里就用最简单mnist数据集,当然如果你想换别的数据集自行更换就OK。
pytorch传统艺能 用DataLoader加载数据
dataset = torchvision.datasets.MNIST(root="mnist/",
train=True,
download=True,
transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# pytorch传统艺能 用DataLoader加载数据
我们使用DataLoader()读取数据后,用next(iter(data_iter))来返回批量数据,而不能使用 next(data_iter),原理就在这儿。
使用迭代器来返回批量数据,可在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足的问题。
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
可以看到输出
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
每张图像都是一个28x28像素的灰度手写数字,像素值范围从0到1。我们上边设置一个batch_size大小为8,所以这里一组图出来8个。
注意:因为这batch取的是随机的,所以你每次运行显示的八张图是不一样的,和我结果不一样没关系。
加噪过程
假设你还没有阅读过任何扩散模型论文,现在告诉你,扩散模型的前向过程需要给图片加噪声。现在给你提供一个简单的加噪方式:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
如果amount = 0,则返回输入图像不做任何更改。如果amount = 1,则返回纯噪声。
通过这种方式混合输入和噪声,我们可以保持输出在相同的范围内(0到1)。并且这样比较容易实现。写代码时候要注意Tensor的形状,以免pytorch的广播机制被破坏。
def corrupt(x, amount):
# 根据输入的amount 对 图像加噪
noise = torch.rand_like(x)
amount = amount.view(-1, 1, 1, 1)
# 使用.view方法修改形状
return x*(1-amount) + noise*amount
让我们看看这个加噪代码的效果如何:
# 显示一下输入图像
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# 为图像添加噪声
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# 画出添加噪声之后的图像
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
torch.linspace(0, 1, x.shape[0]) 就是生成0-1之间的数,生成数量为一个batch,也就是生成一个等差数列作为噪声添加给一个batch的图像。
模型
我们希望我们的模型可以接受一个28*28像素大小的带噪声的输入,并输出一个相同形状的去噪结果。在常规的扩散模型中使用的是U-net。不了解的可以看这个文章:浅谈语义分割网络U-Net。该模型最初是为医学图像分割任务而发明的,一个U-Net由一个“压缩路径”和一个“扩展路径”组成,数据通过“压缩路径”被压缩,然后通过“扩展路径”恢复到原始尺寸(类似于自动编码器),但还包括skip connection,允许数据在不同层次上传递信息和梯度。
一些U-Net在每个阶段都有复杂的组成模块,但我们今天只是简单实现一个扩散模型,所以也不搞什么复杂的U-net结构 了,我们在这里构建一个最最简单的U-net示例,模型可以接收一个单通道图像,并通过压缩路径上的三个卷积层(图表和代码中的down_layers)和扩展路径上的三个卷积层,在下行和上行层之间使用跳过连接。模型中使用最大池化进行下采样,使用nn.Upsample进行上采样。下图是大致的架构,显示每个层输出的通道数:
代码里用的激活函数是torch.nn.SiLU():
在压缩路径中,数据通过三个卷积层(存储在self.down_layers中)和最大池化层进行下采样。在每层之后都应用激活函数(存储在self.act中)。对于前两个压缩层,它们的输出还被存储在h列表中,以便在扩展路径中使用它们进行skip connection。在扩展路径中,数据通过三个卷积层(存储在self.up_layers中)进行上采样,并执行skip connection。在每层之后,也要应用激活函数。注意,在第一个上行层之前没有跳跃连接。
class BasicUNet(nn.Module):
# 简易版的U-net
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList([
nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 64, kernel_size=5, padding=2),
])
self.up_layers = torch.nn.ModuleList([
nn.Conv2d(64, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 32, kernel_size=5, padding=2),
nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
])
# 激活函数
self.act = nn.SiLU()
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self, x):
h = []
for i, l in enumerate(self.down_layers):
x = self.act(l(x))
if i < 2:
h.append(x)
x = self.downscale(x)
for i, l in enumerate(self.up_layers):
if i > 0:
x = self.upscale(x)
x += h.pop()
x = self.act(l(x))
return x
验证一下输入输出是否保持同维度。
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
这里我们可以看到输出是:
torch.Size([8, 1, 28, 28])
训练
说一下我们训练阶段的设定。先想一下我们这个模型要做什么:
丢给它带噪声的图片,模型应该生成其去噪结果
所以给定带噪声的noisy_x,模型要努力去恢复x。
这里我们使用均方误差计算模型输出和原始图像的差异。
接下来就是训练模型部分的代码:
-
拿到一个batch的数据
-
破坏数据模拟前向前向加噪
-
将破坏后的数据放入模型中
-
将模型输出结果和原始图像进行比较并计算loss
-
更新模型参数
# Dataloader 加载数据
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 我们要训练多少轮
n_epochs = 10
# 创建网络,将模型丢到GPU上(如果有GPU的话)
net = BasicUNet()
net.to(device)
# 损失函数是用的MSE loss
loss_fn = nn.MSELoss()
# 优化器使用的Adam
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# 记录损失
losses = []
# 训练循环
for epoch in range(n_epochs):
for x, y in train_dataloader:
# 准备好输入数据和加噪数据
# 把数据放到GPU上(如果你有的话)
x = x.to(device)
#设定随机噪声
noise_amount = torch.rand(x.shape[0]).to(device)
# 处理x,获得加噪之后的样本noisy_x
noisy_x = corrupt(x, noise_amount)
# 获取模型输出结果
pred = net(noisy_x)
# 计算loss
loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
# 反向传播更新模型参数
opt.zero_grad()
loss.backward()
opt.step()
# 存储loss记录
losses.append(loss.item())
# 输出每轮训练的loss的平均值
avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')
# 画一下loss
plt.plot(losses)
plt.ylim(0, 0.1);
可以看到输出结果如下:
Finished epoch 0. Average loss for this epoch: 0.025983
Finished epoch 1. Average loss for this epoch: 0.020247
Finished epoch 2. Average loss for this epoch: 0.018660
Finished epoch 3. Average loss for this epoch: 0.017662
Finished epoch 4. Average loss for this epoch: 0.016999
Finished epoch 5. Average loss for this epoch: 0.016730
Finished epoch 6. Average loss for this epoch: 0.016610
Finished epoch 7. Average loss for this epoch: 0.016287
Finished epoch 8. Average loss for this epoch: 0.016084
Finished epoch 9. Average loss for this epoch: 0.015731
模型训练咋样了?
带兄弟们看一下效果嗷:
# 取一组图像
x, y = next(iter(train_dataloader))
x = x[:8] # Only using the first 8 for easy plotting
# 用我们前边给八张图加噪的那个方法,看看模型对不同程度的噪声的回复情况
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)
# 获取模型结构
with torch.no_grad():
preds = net(noised_x.to(device)).detach().cpu()
# 画出结果来
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');
下边我放了是三个运行结果,可以看到倒数第二项开始,恢复结果就不咋地了。但是作为一个简单模型,能获得这样的效果已经狠不戳了,随着我们逐渐优化模型,会获得更好的效果的!
实例1
实例2
实例3
本文正在参加「金石计划」