CartoonGAN论文复现:如何将图像动漫化

226 阅读1分钟

本文分享自华为云社区《cartoongan 图像动漫化》,作者: HWCloudAI 。

本案例是 CartoonGAN: Generative Adversarial Networks for Photo Cartoonization
的论文复现案例

拷贝数据之后,将你想动漫化的图像放到cartoongan-pytorch/test_img/文件夹下,运行后面代码即可

可以切换不同生成风格,Hosoda/Shinkai/Paprika/Hayao

参考:github.com/venture-ani…

拷贝代码和数据

import moxing as mox
# Copy the CartoonGAN code and data from the OBS bucket path into a local
# "cartoongan-pytorch" directory.
# NOTE(review): assumes the first argument is the source and the second the
# destination, judging by the argument order — confirm against the moxing docs.
mox.file.copy_parallel('obs://obs-aigallery-zc/clf/code/cartoongan-pytorch','cartoongan-pytorch')



%cd cartoongan-pytorch

运行代码

import torch
import os
import numpy as np
import torchvision.utils as vutils

from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt

from network.Transformer import Transformer
import argparse

# Command-line options for the cartoonization script.
parser = argparse.ArgumentParser()
# Directory that is scanned for images to cartoonize.
parser.add_argument("--input_dir", default="test_img")
# type=int keeps a value given on the command line consistent with the
# integer default (without it the value would arrive as a string).
# NOTE: this option is parsed but not referenced by the visible script.
parser.add_argument("--load_size", type=int, default=1280)
# Directory containing the "<style>_net_G_float.pth" weight files.
parser.add_argument("--model_path", default="./pretrained_model")
# Switch the style here: Hosoda/Shinkai/Paprika/Hayao
parser.add_argument("--style", default="Hosoda")
# Directory that receives the generated images.
parser.add_argument("--output_dir", default="test_output")
# -1 forces CPU mode; any other value uses CUDA when it is available.
parser.add_argument("--gpu", type=int, default=0)

# parse_known_args() ignores unrecognised arguments (for example the ones a
# notebook kernel is started with) instead of exiting like parse_args() would.
opt, unknown = parser.parse_known_args()
# File extensions that are treated as input images.
valid_ext = [".jpg", ".png", ".jpeg"]

# setup
# Make sure the input and output directories exist. exist_ok=True creates a
# directory only when it is missing, without the race between a separate
# existence check and the creation call.
os.makedirs(opt.input_dir, exist_ok=True)
os.makedirs(opt.output_dir, exist_ok=True)

# load pretrained model
model = Transformer()
weights_path = os.path.join(opt.model_path, opt.style + "_net_G_float.pth")
# map_location="cpu" lets a checkpoint that contains CUDA tensors be loaded on
# a machine without a GPU; the model is moved to the GPU below when one is used.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
# Inference only: switch the network to evaluation mode.
model.eval()

# CPU is used when explicitly requested (--gpu -1) or when CUDA is unavailable.
# NOTE: any other --gpu value selects the default CUDA device; the index itself
# is not used to pick a device.
disable_gpu = opt.gpu == -1 or not torch.cuda.is_available()

if disable_gpu:
    print("CPU mode")
    model.float()
else:
    print("GPU mode")
    model.cuda()

# Cartoonize every supported image found in the input directory and show the
# original next to the stylized result.
for files in os.listdir(opt.input_dir):
    # splitext copes with extensions of any length; slicing off four
    # characters would turn "a.jpeg" into "a." and produce "a._<style>.jpg".
    stem, ext = os.path.splitext(files)
    # Compare case-insensitively so that e.g. "photo.JPG" is not skipped.
    if ext.lower() not in valid_ext:
        continue

    input_path = os.path.join(opt.input_dir, files)
    output_path = os.path.join(opt.output_dir, stem + "_" + opt.style + ".jpg")

    # load image (the context manager closes the file once it is decoded)
    with Image.open(input_path) as source:
        original = np.asarray(source.convert("RGB"))
    # RGB -> BGR
    input_image = original[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess, (-1, 1)
    input_image = -1 + 2 * input_image
    if disable_gpu:
        input_image = input_image.float()
    else:
        input_image = input_image.cuda()

    # forward; no_grad avoids building an autograd graph during inference
    with torch.no_grad():
        output_image = model(input_image)
    output_image = output_image[0]
    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # map the output from (-1, 1) back to (0, 1) before saving
    output_image = output_image.cpu().float() * 0.5 + 0.5
    # save
    vutils.save_image(output_image, output_path)

    # Re-read the saved file so the preview shows exactly what was written.
    with Image.open(output_path) as saved:
        style = np.array(saved)

    plt.figure(figsize=(20, 20))  # figure size used for the comparison
    # A fresh figure is created per image, so the grid is always 1 x 2;
    # tying the row count to the loop index would shrink later images.
    plt.subplot(1, 2, 1)
    plt.imshow(original)
    plt.subplot(1, 2, 2)
    plt.imshow(style)
    plt.show()

print("Done!")

点击关注,第一时间了解华为云新鲜技术~