AI风格迁移(p2gan)的使用以及工程化

314 阅读2分钟

背景

本文主要介绍风格迁移算法库GitHub - i-evi/p2gan: Patch Permutation GAN(P²-GAN)的使用以及转换成ONNX的过程中遇到的坑,下面是转换成onnx后,使用ONNXRuntime + OpenCV渲染的结果,这里的模型是基于梵高的画训练出来的 image.png

P2GAN的基本使用方式

P2GAN文档中对于训练和处理图片介绍的比较详细

训练

python train.py --model ./mymodels/vangop --style mystyle/vango_portrait.jpg --dataset H:\AIImageDatasets\VOCtrainval_06-Nov-2007\VOCdevkit\VOC2007\JPEGImages --lambda 5e-6

训练集使用的是VOC2007中的图片。风格图片你只需要指定一张,比如梵高的画作即可,如果你想训练出其他风格的模型,可以自行选择其他风格图片

Van-Gogh-The-Starry-Night.jpg

图片转换

通过下面的命令,可以将./tests/文件夹中的图片全部风格化,并输出到output文件夹中

python render.py --model ./custom/models/0 --inp ./tests/ --oup output --size 1024

P2GAN工程化

模型保存为PB格式


import tensorflow as tf
import numpy as np

from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph

import numpy as np
import tensorflow as tf
import util
import model

imbatch_read     = util.imbatch_read
img_write        = util.img_write
pickup_list      = util.pickup_list
ls_files_to_json = util.ls_files_to_json

build_generator     = model.build_generator

ROOT_PATH = "./custom/models/12/"

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

DEVICE = '/cpu:0'
with tf.device(DEVICE),  tf.Session(config=config) as sess:
    input_r = tf.placeholder(tf.float32, shape=[1, 1024, 1024, 3], name='inpr')
    g_state = build_generator(input_r, name='generator')
    g_var_ls = tf.trainable_variables(scope='generator')
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(g_var_ls)
    chkpt_fname = tf.train.latest_checkpoint(ROOT_PATH)
    saver.restore(sess, chkpt_fname)
    
    inputs = ['inpr']

    outputs = ['generator/output/Tanh']
    type = ["DT_FLOAT"]
    graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, outputs)
    graph_def = optimize_for_inference_lib.optimize_for_inference(graph_def, inputs, outputs, type)
    graph_def = TransformGraph(graph_def, inputs, outputs, ['sort_by_execution_order'])

    filename = "./custom/onnx/1.pb"
    with tf.gfile.FastGFile(filename, 'wb') as f:
        f.write(graph_def.SerializeToString())

ROOT_PATH表示checkpoints的目录,导出其他模型修改它即可。这里对我来说,最大的难点就是找outputsinputs名字,最后在model.py中加入tf.identity()配合TensorBoard找到了输入输出节点的名称,TensorFlow掌握的还是太烂了,得继续努力学了。

PB转换为ONNX

这步看起来很简单,一行命令搞定

python -m tf2onnx.convert --graphdef ./custom/onnx/0.pb --output ./custom/onnx/vango.onnx --inputs inpr:0 --outputs generator/output/Tanh:0 --opset 13

但是出了一个大问题(对于菜鸟我来说),ONNX导出后,使用ONNXRuntime加载,提示说FusedPadConv2D找不到,经过艰难的分析,发现FusedPadConv2D其实是由TF这边开启padding='VALID'Conv2D转过来的,ONNX并不支持这个OP。最后我用了一个比较笨的办法,先把padding='VALID'改成padding='SAME',然后把model.py里面conv2d前的_fixed_padding都去掉,最后测试了一下效果,竟然没啥影响,那就先这样吧,哈哈哈

使用ONNX模型

使用模型时注意输入输出的格式即可,输入的数据shape是(1, 1024, 1024, 3),需要规范化到-1到1。输出的数据shape是(1, 1024, 1024, 3),范围也是-1到1,需要重新映射到0~255。

def process_vango(source, model):
    global sess_map
    
    img = np.expand_dims(((source / 255.0 - 0.5) * 2.0).astype(np.float32), axis=0)

    print(img.shape)
    if model in sess_map:
        sess = sess_map[model]
    else:
        onnx_model = onnx.load_model(os.path.join(os.path.dirname(__file__), model))
        sess = ort.InferenceSession(onnx_model.SerializeToString())
        sess.set_providers(['CPUExecutionProvider'])
        sess_map[model] = sess
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    output = sess.run([output_name], {input_name : img})


    res = output[0][0]
    return ((res + 1.0) * 0.5 * 255.0).astype(np.uint8)

更进一步

  1. 现在ONNX模型输入是固定死的,可以做成动态输入
  2. 现在是在PC上测试的,可以在其他平台包括移动端测试看看效果如何