前言
很久没有弄点好玩的东西了,逛逛github,想到最近看到那种动漫化的头像很可爱,于是决定也去找找相关的项目玩玩,毕竟直接调用百度啥的api太没意思了。
正文
之前写过一些有关GAN的博客,大家应该对GAN有了基本的了解。最基础的内容就是基于零和博弈的思想使得生成器生成的假图像逼真的可以足够骗过判别器,更本质就是生成器从对抗中不断学习,学习到了真实图像中的数据分布。但是真正动手之后会发现训练GAN其实还是比较困难的,难点就在于难以收敛,而且还有模式崩塌的情况出现。
再说回这次的目的,想要实现头像动漫化,这在概念上应该是图像的风格迁移,那最基本就应该会想到内容损失和风格损失,当然对具体的问题会提出更多不同的损失。在github上搜索下就会发现这几个名字
- CartoonGAN
- AnimeGAN
- AnimeGAN2 再去看看相关的论文和github链接,会发现上面的顺序就是逐渐优化的过程,具体的论文阅读部分我会放在论文专栏,今天的注意力还是集中在实现上。我最终还是选择了SOTA模型AnimeGAN2-tensorflow
当然也有pytorch版本AnimeGAN2-pytorch
我选择的是pytorch版本,因为tensorflow是1.X版本实现,而我环境已经全部转为TF2了,而且越来越觉得TF复杂而杂乱的api很不喜欢,更多的网络实现都是基于keras,有时候又失去了TF本身的灵活性。好啦,又扯远了,现在开始正式动手实现。
开始实现
进去AnimeGAN2-pytorch,如果只需要实现的话不需要下载整个项目,只需要下载model.py
import torch
from torch import nn
import torch.nn.functional as F
class ConvNormLReLU(nn.Sequential):
def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
pad_layer = {
"zero": nn.ZeroPad2d,
"same": nn.ReplicationPad2d,
"reflect": nn.ReflectionPad2d,
}
if pad_mode not in pad_layer:
raise NotImplementedError
super(ConvNormLReLU, self).__init__(
pad_layer[pad_mode](padding),
nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
nn.LeakyReLU(0.2, inplace=True)
)
class InvertedResBlock(nn.Module):
def __init__(self, in_ch, out_ch, expansion_ratio=2):
super(InvertedResBlock, self).__init__()
self.use_res_connect = in_ch == out_ch
bottleneck = int(round(in_ch * expansion_ratio))
layers = []
if expansion_ratio != 1:
layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
# dw
layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
# pw
layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
self.layers = nn.Sequential(*layers)
def forward(self, input):
out = self.layers(input)
if self.use_res_connect:
out = input + out
return out
class Generator(nn.Module):
def __init__(self, ):
super().__init__()
self.block_a = nn.Sequential(
ConvNormLReLU(3, 32, kernel_size=7, padding=3),
ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
ConvNormLReLU(64, 64)
)
self.block_b = nn.Sequential(
ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
ConvNormLReLU(128, 128)
)
self.block_c = nn.Sequential(
ConvNormLReLU(128, 128),
InvertedResBlock(128, 256, 2),
InvertedResBlock(256, 256, 2),
InvertedResBlock(256, 256, 2),
InvertedResBlock(256, 256, 2),
ConvNormLReLU(256, 128),
)
self.block_d = nn.Sequential(
ConvNormLReLU(128, 128),
ConvNormLReLU(128, 128)
)
self.block_e = nn.Sequential(
ConvNormLReLU(128, 64),
ConvNormLReLU(64, 64),
ConvNormLReLU(64, 32, kernel_size=7, padding=3)
)
self.out_layer = nn.Sequential(
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
nn.Tanh()
)
def forward(self, input, align_corners=True):
out = self.block_a(input)
half_size = out.size()[-2:]
out = self.block_b(out)
out = self.block_c(out)
if align_corners:
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_d(out)
if align_corners:
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_e(out)
out = self.out_layer(out)
return out
其中包含了生成器代码,这也是生成动漫风格图片的关键部分。 然后就是需要去对应的网盘下载不同风格下训练好的模型,可能无法访问,所以我把我下载的放到网盘分享。 模型链接 密码: n78v
然后就是自己写个生成图像的代码,当然可以根据原项目中的test_faces.ipynb来修改 下面是我的生成代码,需要提前将模型放在同一目录下,并且创建samples文件夹用来存放生成图片
import os
import cv2
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
from model import Generator
def load_image(path, size=None):
image = image2tensor(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))
w, h = image.shape[-2:]
if w != h:
crop_size = min(w, h)
left = (w - crop_size) // 2
right = left + crop_size
top = (h - crop_size) // 2
bottom = top + crop_size
image = image[:, :, left:right, top:bottom]
if size is not None and image.shape[-1] != size:
image = torch.nn.functional.interpolate(image, (size, size), mode="bilinear", align_corners=True)
return image
def image2tensor(image):
image = torch.FloatTensor(image).permute(2, 0, 1).unsqueeze(0) / 255.
return (image - 0.5) / 0.5
def tensor2image(tensor):
tensor = tensor.clamp_(-1., 1.).detach().squeeze().permute(1, 2, 0).cpu().numpy()
return tensor * 0.5 + 0.5
def imshow(img, size=5, cmap='jet'):
plt.figure(figsize=(size, size))
plt.imshow(img, cmap=cmap)
plt.axis('off')
plt.show()
if __name__ == '__main__':
device = 'cuda'
torch.set_grad_enabled(False)
image_size = 300
img=input("")
model = Generator().eval().to(device)
ckpt = torch.load(f"./new.pth", map_location=device)
model.load_state_dict(ckpt)
result=[]
image = load_image(f"./face/{img}", image_size)
output = model(image.to(device))
result.append(torch.cat([image, output.cpu()], 3))
result = torch.cat(result, 2)
imshow(tensor2image(result), 40)
cv2.imwrite(f'./samples/new+{img}', cv2.cvtColor(255 * tensor2image(result), cv2.COLOR_BGR2RGB))
在device中可以选择cpu或者cuda(使用GPU)
可能遇到的问题
当电脑中pytorch版本低于1.6时,载入模型时torch.load会报错,但是受限于显卡算力好像不支持安装高版本torch,那只能找其他解决办法。
找个笔记本,一般都可以安装高版本pytorch(大于1.6就行),配好环境之后运行下面代码, 用新生成的模型文件去载入,也就是我代码中的new.pth 修改模型文件的代码
import torch
weight = torch.load("高版本的模型地址")
torch.save(weight, '自定义新的模型地址', _use_new_zipfile_serialization=False)
效果
效果还是比较不错的,生成图片的速度也比较快,但是我感觉还是存在一些问题
- 由于代码中会对图片进行截取,所以有可能丢失原图中的信息
- 可能是训练数据集的问题,传入光照不好的图片出现的效果很不好,轮廓曲线十分模糊。