如何去除图像马赛克?(加强版)

826 阅读12分钟

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

一、前言

在之前的文章中,我们讨论过图像修复的方法。具体见:如何去除图片马赛克?

这里回顾一些关键的内容。这里我们要考虑几个问题:

  1. 能否去除马赛克?
  2. 去除马赛克实际在做什么?
  3. 如何实现去除马赛克?

首先第一个问题,严格来说是不能的,在添加马赛克时,原图会被破坏,破坏后的图片中不再包含原有信息,因此马赛克是无法去除的。

那我们说的去除马赛克是什么意思呢?实际上我们并不是去除马赛克,而是预测马赛克区域的内容。

那么如何实现去除马赛克呢?在之前的文章中,使用了一个非常简单的Unet网络,我们直接将图像添加马赛克,然后预测没有添加马赛克的图像。这种方法非常简单有效,而且方便训练。但如果想要更好的效果,需要对原有模型做一些修改,这个新的模型就是Codeformer了。

本文我们将介绍Codeformer的基本原理,并使用代码实现一个简易的Codeformer。

二、Codeformer

Codeformer架构是一种用于图像高清重构和图形修复的模型,Codeformer不只可以修复马赛克,还可以修复画笔功能造成的图像破坏,以及将低分辨率图像转换成高分辨率图像。下面我们来看看Codeformer的实现细节。

2.1 Codeformer原理

Codeformer依旧是经典的Encoder-Decoder架构,但是与之前的不同,Codeformer使用了两组Encoder-Decoder,其架构如图所示:

image.png

2.1.1 HQVAE

首先是HQVAE,HQVAE由HQ Encoder和HQ Decoder构成,我们会使用高清图像来训练高清VAE。正常的VAE用Encoder将图像编码成一个分布,然后使用这个分布采样出向量z,最后使用Decoder将z解码成原图像。

Codeformer在VAE中引入了一个Codebook。假设z是一个1024*16的特征图,Codeformer中会有一个M*16的可训练参数,然后将z中每行向量用Codeformer中与之最近的行向量替换,最后将替换后的z用于解码图像。

这个过程相当于做了一次量化操作,原本我们生成的z有无限可能,而使用Codebook转换后,z的每行只有1024种可能。以人脸重构任务为例,所有人脸的某个局部范围都是基本一致,比如眼球、鼻孔、牙齿等,而Codebook中的每行可能就代表人脸的某个局部特征。用Codebook转换z的过程就是“将像牙齿的地方,替换成牙齿通常的样子,将像鼻孔的地方,替换成像鼻孔的样子。”

不过实际上Codebook中的含义是比鼻孔、眼球这种特征更为抽象的特征,上面我们说的只是一种形象的理解。

2.1.2 重构VAE

Codeformer需要使用高清图像训练好HQVAE,然后开始训练用于图像重构和高清修复的网络。然后使用LQ Encoder、Codebook和HQ Decoder构成一个新的重构VAE。

整体流程和HQVAE的训练类似,有两处不同。第一处是从Codebook取向量时,HQVAE使用的是最近邻方法,而重构VAE中使用的是Transformer分类器预测对应的Code。第二处则是解码过程中加入了CFT模块。这里CFT模块是非必要的,因此在后面的实践中我们舍去。

在开始实现前我们先导入需要的模块:

from glob import glob

import cv2
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm

三、创建数据集

在这里我们有两个网络需要训练,因此我们需要创建两个数据集类。首先是HQVAE的训练,我们就是输入高清图像,然后输出高清图像,因此我们只需要把读取的图像返回两个就行。代码具体如下:

class BaseDataset(Dataset):
    def __init__(self, data_root=r"F:\celeba*", image_size=96):
        super().__init__()
        self.image_size = image_size
        # 预处理
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        self.image_paths = glob(data_root)

    def __len__(self):
        return len(self.image_paths)


class HQDataset(BaseDataset):

    def __getitem__(self, item):
        image = Image.open(self.image_paths[item])
        image = self.trans(image)
        return image, image

这里我们先实现了一个BaseDataset,我们两个数据集都会继承这个类。在BaseDataset中我们拿到了图片的所有路径、创建了一个trans用于后面处理图像,实现了len方法。

在HQDataset中,我们只需要实现getitem即可,而且在getitem中只需要读取图像,然后使用trans处理,最后x和y都返回image即可。

然后是LQDataset,这里我们还是继承BaseDataset,然后实现getitem方法。LQDataset和HQDataset的区别在于x是添加马赛克或下采样的图像,这里我们编写两个方法用于添加马赛克和降采样,首先是添加马赛克我们要做的就是选取一个区域,然后对该区域下采样然后在上采样,代码如下:

@staticmethod
def add_mosaic(image: Image.Image, mosaic_size: int, region_size: int) -> Image.Image:

    img_width, img_height = image.size

    # 随机选择一个区域的左上角坐标
    x = random.randint(0, img_width - region_size)
    y = random.randint(0, img_height - region_size)

    # 获取选中的区域
    region = image.crop((x, y, x + region_size, y + region_size))

    # 对该区域进行下采样再上采样,产生马赛克效果
    region = region.resize((mosaic_size, mosaic_size), Image.NEAREST)  # 缩小
    region = region.resize((region_size, region_size), Image.NEAREST)  # 放大

    # 将马赛克区域粘回到原图
    image.paste(region, (x, y))
    return image

然后是生成低分辨率但是尺寸不变的图像(高清修复任务),这里我们要做的就是对整个图片下采样然后上采样,代码如下:

@staticmethod
def reduce_image_quality(image: Image.Image, quality_factor: float) -> Image.Image:
    if not (0 < quality_factor <= 1):
        raise ValueError("quality_factor 应该在 (0, 1] 范围内")

    # 获取图像的原始尺寸
    original_size = image.size

    # 根据 quality_factor 计算缩小后的尺寸
    new_size = (int(original_size[0] * quality_factor), int(original_size[1] * quality_factor))

    # 缩小图像
    low_res_image = image.resize(new_size, Image.BILINEAR)  # 使用双线性插值缩小

    # 再将图像放大回原尺寸
    low_res_image = low_res_image.resize(original_size, Image.BILINEAR)  # 使用双线性插值放大

    return low_res_image

实现效果如下:

image.png

然后是getitem方法,我们要用这个数据集完成去除马赛克和高清修复两个任务,因此我们要随机对图像添加马赛克或者降低分辨率,代码如下:

class LQDataset(BaseDataset):

    def __getitem__(self, item):
        image = Image.open(self.image_paths[item])
        # 随机选择是添加马赛克还是降低分辨率
        if random.choice([True, False]):
            # 随机选择马赛克强度
            mosaic_size = random.randint(5, 25)
            region_size = random.randint(50, 150)
            x = self.add_mosaic(image.copy(), mosaic_size, region_size)
        else:
            # 随机选择分辨率降低程度
            quality_factor = random.randint(2, 10)
            x = self.reduce_image_quality(image.copy(), quality_factor)
        return self.trans(x), self.trans(image)

四、网络结构

在Codeformer中,HQVAE和LQVAE结构有所不同,但是我们的实现中采用同一结构,即包含Encoder、CodePredictor、Codebook、Decoder四个部分。我们一一来实现。

4.1 其余部件

在开始实现四个部分前,我们先实现两个通用的卷积层,代码如下:

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)


class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)

这里我们简单把Conv2d、残差连接、BatchNorm组成卷积模块。把ConvTranspose2d和BatchNorm组成转置卷积模块。这里使用与前面数字人基本一致的结构。

4.2 Encoder

Encoder要做的就是不断对图像卷积并下采样,代码如下:

class HQEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = self.encoder = nn.ModuleList([
            nn.Sequential(
                Conv2d(3, 16, kernel_size=7, stride=1, padding=3)
            ),  # 96,96
            nn.Sequential(
                Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 48,48
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
            ),
            nn.Sequential(
                Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 24,24
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
            ),
            nn.Sequential(
                Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 12,12
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
            ),
            nn.Sequential(
                Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 6,6
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            ),
            nn.Sequential(
                nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
            )
        ])

    def forward(self, x):
        features = []
        for block in self.model:
            x = block(x)
            features.append(x)
        return features

在正向传播中,我们保存了每个block的输出。上面的Encoder有六个block,输入(batch_size,3,96,96)的图像,各个block输出的特征图形状为:

torch.Size([16, 16, 96, 96])
torch.Size([16, 32, 48, 48])
torch.Size([16, 64, 24, 24])
torch.Size([16, 128, 12, 12])
torch.Size([16, 256, 6, 6])
torch.Size([16, 512, 6, 6])

基于上面的输出,我们构造Decoder部分。

4.3 Decoder

首先Decoder需要有和Encoder同等数量的block。第一个block输入(16, 512, 6, 6)的特征图,而后每次将上一个block的输出与Encoder每个block的输出合并作为下一个DecoderBlock的输入,代码如下:

class HQDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.ModuleList([
            nn.Sequential(
                Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
            ),  # 6,6
            nn.Sequential(
                nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            ),  # 6,6
            nn.Sequential(
                Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True)
            ),  # 12, 12
            nn.Sequential(
                Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)
            ),  # 24, 24
            nn.Sequential(
                Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
            ),  # 48, 48
            nn.Sequential(
                Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
            )  # 96,96
        ])

    def forward(self, x, features=None):
        for block in self.model:
            x = block(x)
            x = torch.concatenate([x, features[-1]], dim=1)
            features.pop()
        return x

这里每次特征图的形状和Encoder是反过来的,唯一需要计算的是通道数。比如block2的输入通道数是1025,这是因为block1的输出通道是512,Decoder最后一个输出的通道数是512。再比如block3的输入通道数是768,这是因为block2的输出通道是512,而Decoder倒数第二的输出通道数是256。

4.4 CodePredictor

在Codeformer中CodePredictor使用的是Transformer网络,而且只在LQVAE中使用,这里我们HQVAE和LQVAE都使用同样的CodePredictor。

首先关注Encoder的输出z,其形状为(batch_size,512,6,6),这里表示z有512个code,每个code是长度为36的向量,假设我们指定codebook的大小为256,那么CodePredictor则是一个输入36维度,输出256个类别的分裂网络,这里使用一个简单的线性网络:

nn.Linear(6 * 6, n_codes)

4.5 Codebook

Codebook就是一个可训练的特征表,这里很容易联想到Embedding,在我们的实现中也确实使用Embedding完成。

Codebook的作用是替换z中的向量,因此Codebook是一个(n_codes,36)的Embedding:

self.codebook = nn.Embedding(n_codes, 36)

4.6 HQVAE

下面我们把四个模块整合,代码如下:

class HQVAE(nn.Module):
    def __init__(self, n_codes=256, code_dim=36):
        super().__init__()
        self.encoder = HQEncoder()
        self.code_predictor = nn.Linear(6 * 6, n_codes)
        self.codebook = nn.Embedding(n_codes, code_dim)
        self.decoder = HQDecoder()
        self.output_block = nn.Sequential(
            Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

上面我们把四个模块简单放在一起了,下面最关键的是forward函数。forward的整体流程如下:

  1. Encoder编码图像成z
  2. 使用codePredictor预测Code
  3. 使用Codebook量化z
  4. 使用量化后的z解码出图像

下面是具体实现:

def forward(self, x):
    batch_size = x.size(0)
    features = self.encoder(x)
    # 解码出z
    z = features[-1]
    idxes = torch.LongTensor(batch_size, 512).to(x.device)
    # 预测出所有code
    for i in range(z.size(1)):
        probs = F.softmax(self.code_predictor(z[:, i].view(batch_size, -1)), dim=1)
        idx = probs.argmax(dim=1)
        idxes[:, i] = idx
    # 用codebook量化z
    quantized_z = self.codebook(idxes).view(batch_size, 512, 6, 6)
    # 使用量化后的z解码出图像
    return self.output_block(self.decoder(quantized_z, features))

下面我们就可以开始训练了。

五、训练

首先是训练HQVAE,这里就是常规的pytorch训练代码,具体如下:

def train_hq():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HQVAE().to(device)
    hq_dataset = HQDataset(data_root=r'F:\img_align_celeba*')
    hq_dataloader = DataLoader(hq_dataset, batch_size=64)
    loss_fn = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    global_step = 0
    for epoch in range(10):
        for x, y in tqdm(hq_dataloader, total=len(hq_dataloader)):
            x = x.to(device)
            y = y.to(device)
            predictions = model(x)
            loss = loss_fn(y, predictions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step == 0 or global_step % 50 == 0:
                print(f'epoch: {epoch}, loss: {loss.item()}')
            if global_step % 300 == 0:
                outputs = make_grid(predictions, 8)
                img = outputs.cpu().numpy().transpose(1, 2, 0)
                img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                Image.fromarray(img).show()
                torch.save(model.state_dict(), 'lqvae.pth')
            global_step += 1
    torch.save(model.state_dict(), 'hqvae.pth')

首先我们使用上述代码训练好HQVAE,然后再加载HQVAE。在训练LQVAE时,我们会使用HQVAE的Codebook、Decoder、output_block,并冻结这些模块的参数,其余模块则用来训练,代码如下:

def train_lq():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HQVAE()
    # 加载除Encoder和CodePredictor外的模块
    new_state_dict = model.state_dict()
    state_dict = torch.load('hqvae.pth')
    for k, v in state_dict.items():
        if k.__contains__('encoder') or k.__contains__('code_predictor'):
            continue
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    # 将Codebook、Decoder、output_block设置为不计算梯度
    model.codebook.requires_grad = False
    model.decoder.requires_grad = False
    # model.output_block.requires_grad = False
    model = model.to(device)

    lq_dataset = LQDataset(r'F:\img_align_celeba*')
    lq_dataloader = DataLoader(lq_dataset, batch_size=64)

    loss_fn = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    global_step = 0
    for epoch in range(10):
        for x, y in tqdm(lq_dataloader, total=len(lq_dataloader)):
            x = x.to(device)
            y = y.to(device)
            predictions = model(x)
            loss = loss_fn(y, predictions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step == 0 or global_step % 50 == 0:
                print(f'epoch: {epoch}, loss: {loss.item()}')
            if global_step % 300 == 0:
                torch.save(model.state_dict(), 'lqvae.pth')

            global_step += 1
    torch.save(model.state_dict(), 'lqvae.pth')

最后我们使用训练好的LQVAE作为去除马赛克和重构网络。重构效果如图所示:

image.png

其中第一行是原图,第二行是重构图像,第三行则是低分辨率或马赛克图像。可以看到效果还是不错的。