背景
本文主要介绍风格迁移算法库GitHub - i-evi/p2gan: Patch Permutation GAN(P²-GAN)的使用以及转换成ONNX的过程中遇到的坑,下面是转换成onnx后,使用ONNXRuntime + OpenCV渲染的结果,这里的模型是基于梵高的画训练出来的
P2GAN的基本使用方式
P2GAN文档中对于训练和处理图片介绍的比较详细
训练
python train.py --model ./mymodels/vangop --style mystyle/vango_portrait.jpg --dataset H:\AIImageDatasets\VOCtrainval_06-Nov-2007\VOCdevkit\VOC2007\JPEGImages --lambda 5e-6
训练集使用的是VOC2007中的图片。风格图片你只需要指定一张,比如梵高的画作即可,如果你想训练出其他风格的模型,可以自行选择其他风格图片
图片转换
通过下面的命令,可以将./tests/
文件夹中的图片全部风格化,并输出到output
文件夹中
python render.py --model ./custom/models/0 --inp ./tests/ --oup output --size 1024
P2GAN工程化
模型保存为PB格式
import tensorflow as tf
import numpy as np
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph
import numpy as np
import tensorflow as tf
import util
import model
imbatch_read = util.imbatch_read
img_write = util.img_write
pickup_list = util.pickup_list
ls_files_to_json = util.ls_files_to_json
build_generator = model.build_generator
ROOT_PATH = "./custom/models/12/"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
DEVICE = '/cpu:0'
with tf.device(DEVICE), tf.Session(config=config) as sess:
input_r = tf.placeholder(tf.float32, shape=[1, 1024, 1024, 3], name='inpr')
g_state = build_generator(input_r, name='generator')
g_var_ls = tf.trainable_variables(scope='generator')
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(g_var_ls)
chkpt_fname = tf.train.latest_checkpoint(ROOT_PATH)
saver.restore(sess, chkpt_fname)
inputs = ['inpr']
outputs = ['generator/output/Tanh']
type = ["DT_FLOAT"]
graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, outputs)
graph_def = optimize_for_inference_lib.optimize_for_inference(graph_def, inputs, outputs, type)
graph_def = TransformGraph(graph_def, inputs, outputs, ['sort_by_execution_order'])
filename = "./custom/onnx/1.pb"
with tf.gfile.FastGFile(filename, 'wb') as f:
f.write(graph_def.SerializeToString())
ROOT_PATH
表示checkpoints的目录,导出其他模型修改它即可。这里对我来说,最大的难点就是找outputs
和inputs
名字,最后在model.py
中加入tf.identity()
配合TensorBoard找到了输入输出节点的名称,TensorFlow掌握的还是太烂了,得继续努力学了。
PB转换为ONNX
这步看起来很简单,一行命令搞定
python -m tf2onnx.convert --graphdef ./custom/onnx/0.pb --output ./custom/onnx/vango.onnx --inputs inpr:0 --outputs generator/output/Tanh:0 --opset 13
但是出了一个大问题(对于菜鸟我来说),ONNX导出后,使用ONNXRuntime加载,提示说FusedPadConv2D
找不到,经过艰难的分析,发现FusedPadConv2D
其实是由TF这边开启padding='VALID'
的Conv2D
转过来的,ONNX并不支持这个OP。最后我用了一个比较笨的办法,先把padding='VALID'
改成padding='SAME'
,然后把model.py里面conv2d前的_fixed_padding都去掉,最后测试了一下效果,竟然没啥影响,那就先这样吧,哈哈哈
使用ONNX模型
使用模型时注意输入输出的格式即可,输入的数据shape是(1, 1024, 1024, 3)
,需要规范化到-1到1。输出的数据shape是(1, 1024, 1024, 3)
,范围也是-1到1,需要重新映射到0~255。
def process_vango(source, model):
global sess_map
img = np.expand_dims(((source / 255.0 - 0.5) * 2.0).astype(np.float32), axis=0)
print(img.shape)
if model in sess_map:
sess = sess_map[model]
else:
onnx_model = onnx.load_model(os.path.join(os.path.dirname(__file__), model))
sess = ort.InferenceSession(onnx_model.SerializeToString())
sess.set_providers(['CPUExecutionProvider'])
sess_map[model] = sess
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
output = sess.run([output_name], {input_name : img})
res = output[0][0]
return ((res + 1.0) * 0.5 * 255.0).astype(np.uint8)
更进一步
- 现在ONNX模型输入是固定死的,可以做成动态输入
- 现在是在PC上测试的,可以在其他平台包括移动端测试看看效果如何