本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
一、前言
在之前的文章中,我们讨论过图像修复的方法。具体见:如何去除图片马赛克?
这里回顾一些关键的内容。这里我们要考虑几个问题:
- 能否去除马赛克?
- 去除马赛克实际在做什么?
- 如何实现去除马赛克?
首先第一个问题,严格来说是不能的,在添加马赛克时,原图会被破坏,破坏后的图片中不再包含原有信息,因此马赛克是无法去除的。
那我们说的去除马赛克是什么意思呢?实际上我们并不是去除马赛克,而是预测马赛克区域的内容。
那么如何实现去除马赛克呢?在之前的文章中,使用了一个非常简单的Unet网络,我们直接将图像添加马赛克,然后预测没有添加马赛克的图像。这种方法非常简单有效,而且方便训练。但如果想要更好的效果,需要对原有模型做一些修改,这个新的模型就是Codeformer了。
本文我们将介绍Codeformer的基本原理,并使用代码实现一个简易的Codeformer。
二、Codeformer
Codeformer架构是一种用于图像高清重构和图形修复的模型,Codeformer不只可以修复马赛克,还可以修复画笔功能造成的图像破坏,以及将低分辨率图像转换成高分辨率图像。下面我们来看看Codeformer的实现细节。
2.1 Codeformer原理
Codeformer依旧是经典的Encoder-Decoder架构,但是与之前的不同,Codeformer使用了两组Encoder-Decoder,其架构如图所示:
2.1.1 HQVAE
首先是HQVAE,HQVAE由HQ Encoder和HQ Decoder构成,我们会使用高清图像来训练高清VAE。正常的VAE用Encoder将图像编码成一个分布,然后使用这个分布采样出向量z,最后使用Decoder将z解码成原图像。
Codeformer在VAE中引入了一个Codebook。假设z是一个1024*16的特征图,Codeformer中会有一个M*16的可训练参数,然后将z中每行向量用Codeformer中与之最近的行向量替换,最后将替换后的z用于解码图像。
这个过程相当于做了一次量化操作,原本我们生成的z有无限可能,而使用Codebook转换后,z的每行只有1024种可能。以人脸重构任务为例,所有人脸的某个局部范围都是基本一致,比如眼球、鼻孔、牙齿等,而Codebook中的每行可能就代表人脸的某个局部特征。用Codebook转换z的过程就是“将像牙齿的地方,替换成牙齿通常的样子,将像鼻孔的地方,替换成像鼻孔的样子。”
不过实际上Codebook中的含义是比鼻孔、眼球这种特征更为抽象的特征,上面我们说的只是一种形象的理解。
2.1.2 重构VAE
Codeformer需要使用高清图像训练好HQVAE,然后开始训练用于图像重构和高清修复的网络。然后使用LQ Encoder、Codebook和HQ Decoder构成一个新的重构VAE。
整体流程和HQVAE的训练类似,有两处不同。第一处是从Codebook取向量时,HQVAE使用的是最近邻方法,而重构VAE中使用的是Transformer分类器预测对应的Code。第二处则是解码过程中加入了CFT模块。这里CFT模块是非必要的,因此在后面的实践中我们舍去。
在开始实现前我们先导入需要的模块:
from glob import glob
import cv2
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm
三、创建数据集
在这里我们有两个网络需要训练,因此我们需要创建两个数据集类。首先是HQVAE的训练,我们就是输入高清图像,然后输出高清图像,因此我们只需要把读取的图像返回两个就行。代码具体如下:
class BaseDataset(Dataset):
def __init__(self, data_root=r"F:\celeba*", image_size=96):
super().__init__()
self.image_size = image_size
# 预处理
self.trans = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
])
self.image_paths = glob(data_root)
def __len__(self):
return len(self.image_paths)
class HQDataset(BaseDataset):
def __getitem__(self, item):
image = Image.open(self.image_paths[item])
image = self.trans(image)
return image, image
这里我们先实现了一个BaseDataset,我们两个数据集都会继承这个类。在BaseDataset中我们拿到了图片的所有路径、创建了一个trans用于后面处理图像,实现了len方法。
在HQDataset中,我们只需要实现getitem即可,而且在getitem中只需要读取图像,然后使用trans处理,最后x和y都返回image即可。
然后是LQDataset,这里我们还是继承BaseDataset,然后实现getitem方法。LQDataset和HQDataset的区别在于x是添加马赛克或下采样的图像,这里我们编写两个方法用于添加马赛克和降采样,首先是添加马赛克我们要做的就是选取一个区域,然后对该区域下采样然后在上采样,代码如下:
@staticmethod
def add_mosaic(image: Image.Image, mosaic_size: int, region_size: int) -> Image.Image:
img_width, img_height = image.size
# 随机选择一个区域的左上角坐标
x = random.randint(0, img_width - region_size)
y = random.randint(0, img_height - region_size)
# 获取选中的区域
region = image.crop((x, y, x + region_size, y + region_size))
# 对该区域进行下采样再上采样,产生马赛克效果
region = region.resize((mosaic_size, mosaic_size), Image.NEAREST) # 缩小
region = region.resize((region_size, region_size), Image.NEAREST) # 放大
# 将马赛克区域粘回到原图
image.paste(region, (x, y))
return image
然后是生成低分辨率但是尺寸不变的图像(高清修复任务),这里我们要做的就是对整个图片下采样然后上采样,代码如下:
@staticmethod
def reduce_image_quality(image: Image.Image, quality_factor: float) -> Image.Image:
if not (0 < quality_factor <= 1):
raise ValueError("quality_factor 应该在 (0, 1] 范围内")
# 获取图像的原始尺寸
original_size = image.size
# 根据 quality_factor 计算缩小后的尺寸
new_size = (int(original_size[0] * quality_factor), int(original_size[1] * quality_factor))
# 缩小图像
low_res_image = image.resize(new_size, Image.BILINEAR) # 使用双线性插值缩小
# 再将图像放大回原尺寸
low_res_image = low_res_image.resize(original_size, Image.BILINEAR) # 使用双线性插值放大
return low_res_image
实现效果如下:
然后是getitem方法,我们要用这个数据集完成去除马赛克和高清修复两个任务,因此我们要随机对图像添加马赛克或者降低分辨率,代码如下:
class LQDataset(BaseDataset):
def __getitem__(self, item):
image = Image.open(self.image_paths[item])
# 随机选择是添加马赛克还是降低分辨率
if random.choice([True, False]):
# 随机选择马赛克强度
mosaic_size = random.randint(5, 25)
region_size = random.randint(50, 150)
x = self.add_mosaic(image.copy(), mosaic_size, region_size)
else:
# 随机选择分辨率降低程度
quality_factor = random.randint(2, 10)
x = self.reduce_image_quality(image.copy(), quality_factor)
return self.trans(x), self.trans(image)
四、网络结构
在Codeformer中,HQVAE和LQVAE结构有所不同,但是我们的实现中采用同一结构,即包含Encoder、CodePredictor、Codebook、Decoder四个部分。我们一一来实现。
4.1 其余部件
在开始实现四个部分前,我们先实现两个通用的卷积层,代码如下:
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act(out)
class Conv2dTranspose(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
def forward(self, x):
out = self.conv_block(x)
return self.act(out)
这里我们简单把Conv2d、残差连接、BatchNorm组成卷积模块。把ConvTranspose2d和BatchNorm组成转置卷积模块。这里使用与前面数字人基本一致的结构。
4.2 Encoder
Encoder要做的就是不断对图像卷积并下采样,代码如下:
class HQEncoder(nn.Module):
def __init__(self):
super().__init__()
self.model = self.encoder = nn.ModuleList([
nn.Sequential(
Conv2d(3, 16, kernel_size=7, stride=1, padding=3)
), # 96,96
nn.Sequential(
Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
),
nn.Sequential(
Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
),
nn.Sequential(
Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
),
nn.Sequential(
Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
),
nn.Sequential(
nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
)
])
def forward(self, x):
features = []
for block in self.model:
x = block(x)
features.append(x)
return features
在正向传播中,我们保存了每个block的输出。上面的Encoder有六个block,输入(batch_size,3,96,96)的图像,各个block输出的特征图形状为:
torch.Size([16, 16, 96, 96])
torch.Size([16, 32, 48, 48])
torch.Size([16, 64, 24, 24])
torch.Size([16, 128, 12, 12])
torch.Size([16, 256, 6, 6])
torch.Size([16, 512, 6, 6])
基于上面的输出,我们构造Decoder部分。
4.3 Decoder
首先Decoder需要有和Encoder同等数量的block。第一个block输入(16, 512, 6, 6)的特征图,而后每次将上一个block的输出与Encoder每个block的输出合并作为下一个DecoderBlock的输入,代码如下:
class HQDecoder(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.ModuleList([
nn.Sequential(
Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
), # 6,6
nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
), # 6,6
nn.Sequential(
Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True)
), # 12, 12
nn.Sequential(
Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)
), # 24, 24
nn.Sequential(
Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
), # 48, 48
nn.Sequential(
Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
) # 96,96
])
def forward(self, x, features=None):
for block in self.model:
x = block(x)
x = torch.concatenate([x, features[-1]], dim=1)
features.pop()
return x
这里每次特征图的形状和Encoder是反过来的,唯一需要计算的是通道数。比如block2的输入通道数是1025,这是因为block1的输出通道是512,Decoder最后一个输出的通道数是512。再比如block3的输入通道数是768,这是因为block2的输出通道是512,而Decoder倒数第二的输出通道数是256。
4.4 CodePredictor
在Codeformer中CodePredictor使用的是Transformer网络,而且只在LQVAE中使用,这里我们HQVAE和LQVAE都使用同样的CodePredictor。
首先关注Encoder的输出z,其形状为(batch_size,512,6,6),这里表示z有512个code,每个code是长度为36的向量,假设我们指定codebook的大小为256,那么CodePredictor则是一个输入36维度,输出256个类别的分裂网络,这里使用一个简单的线性网络:
nn.Linear(6 * 6, n_codes)
4.5 Codebook
Codebook就是一个可训练的特征表,这里很容易联想到Embedding,在我们的实现中也确实使用Embedding完成。
Codebook的作用是替换z中的向量,因此Codebook是一个(n_codes,36)的Embedding:
self.codebook = nn.Embedding(n_codes, 36)
4.6 HQVAE
下面我们把四个模块整合,代码如下:
class HQVAE(nn.Module):
def __init__(self, n_codes=256, code_dim=36):
super().__init__()
self.encoder = HQEncoder()
self.code_predictor = nn.Linear(6 * 6, n_codes)
self.codebook = nn.Embedding(n_codes, code_dim)
self.decoder = HQDecoder()
self.output_block = nn.Sequential(
Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
nn.Sigmoid()
)
上面我们把四个模块简单放在一起了,下面最关键的是forward函数。forward的整体流程如下:
- Encoder编码图像成z
- 使用codePredictor预测Code
- 使用Codebook量化z
- 使用量化后的z解码出图像
下面是具体实现:
def forward(self, x):
batch_size = x.size(0)
features = self.encoder(x)
# 解码出z
z = features[-1]
idxes = torch.LongTensor(batch_size, 512).to(x.device)
# 预测出所有code
for i in range(z.size(1)):
probs = F.softmax(self.code_predictor(z[:, i].view(batch_size, -1)), dim=1)
idx = probs.argmax(dim=1)
idxes[:, i] = idx
# 用codebook量化z
quantized_z = self.codebook(idxes).view(batch_size, 512, 6, 6)
# 使用量化后的z解码出图像
return self.output_block(self.decoder(quantized_z, features))
下面我们就可以开始训练了。
五、训练
首先是训练HQVAE,这里就是常规的pytorch训练代码,具体如下:
def train_hq():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HQVAE().to(device)
hq_dataset = HQDataset(data_root=r'F:\img_align_celeba*')
hq_dataloader = DataLoader(hq_dataset, batch_size=64)
loss_fn = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
global_step = 0
for epoch in range(10):
for x, y in tqdm(hq_dataloader, total=len(hq_dataloader)):
x = x.to(device)
y = y.to(device)
predictions = model(x)
loss = loss_fn(y, predictions)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if global_step == 0 or global_step % 50 == 0:
print(f'epoch: {epoch}, loss: {loss.item()}')
if global_step % 300 == 0:
outputs = make_grid(predictions, 8)
img = outputs.cpu().numpy().transpose(1, 2, 0)
img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
Image.fromarray(img).show()
torch.save(model.state_dict(), 'lqvae.pth')
global_step += 1
torch.save(model.state_dict(), 'hqvae.pth')
首先我们使用上述代码训练好HQVAE,然后再加载HQVAE。在训练LQVAE时,我们会使用HQVAE的Codebook、Decoder、output_block,并冻结这些模块的参数,其余模块则用来训练,代码如下:
def train_lq():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HQVAE()
# 加载除Encoder和CodePredictor外的模块
new_state_dict = model.state_dict()
state_dict = torch.load('hqvae.pth')
for k, v in state_dict.items():
if k.__contains__('encoder') or k.__contains__('code_predictor'):
continue
new_state_dict[k] = v
model.load_state_dict(new_state_dict)
# 将Codebook、Decoder、output_block设置为不计算梯度
model.codebook.requires_grad = False
model.decoder.requires_grad = False
# model.output_block.requires_grad = False
model = model.to(device)
lq_dataset = LQDataset(r'F:\img_align_celeba*')
lq_dataloader = DataLoader(lq_dataset, batch_size=64)
loss_fn = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
global_step = 0
for epoch in range(10):
for x, y in tqdm(lq_dataloader, total=len(lq_dataloader)):
x = x.to(device)
y = y.to(device)
predictions = model(x)
loss = loss_fn(y, predictions)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if global_step == 0 or global_step % 50 == 0:
print(f'epoch: {epoch}, loss: {loss.item()}')
if global_step % 300 == 0:
torch.save(model.state_dict(), 'lqvae.pth')
global_step += 1
torch.save(model.state_dict(), 'lqvae.pth')
最后我们使用训练好的LQVAE作为去除马赛克和重构网络。重构效果如图所示:
其中第一行是原图,第二行是重构图像,第三行则是低分辨率或马赛克图像。可以看到效果还是不错的。