RealBasicVSR如何将模型转换成onnx

1,207 阅读1分钟

背景

为了方便RealBasicVSR部署到其他平台,尝试将RealBasicVSR的模型转换成onnx格式,记录一下过程中遇到的问题和解决方案

导出onnx模型

我直接在inference_realbasicvsr.py中的outputs = model(inputs, test_mode=True)['output'].cpu()后面加入下列代码导出onnx模型

torch.onnx.export(model,
                inputs,
                'model.onnx',
                input_names= ['input'],
                output_names=['output'],
                opset_version=11,
                dynamic_axes={'input' : {0 : 'batch_size', 3 : 'w', 4 : 'h'}, 'output' : {0 : 'batch_size', 3 : 'dstw', 4 : 'dsth'}})

通过dynamic_axes设置输入输出为动态大小

使用ONNX模型

上一个最小可用的例子,就是简单的加载图片和模型,然后处理,显示和保存,使用了numpy的一些方法让数据的shape和模型输入保持一致

import imp
import onnxruntime as ort
import numpy as np
import onnx
import cv2

img = cv2.imread("./data/demos/anim2.jpg")
img = np.expand_dims((img/255.0).astype(np.float32).transpose(2,0,1), axis=0)

imgs = np.array([img])
print(imgs.shape)

onnx_model = onnx.load_model("model.onnx")
sess = ort.InferenceSession(onnx_model.SerializeToString())
sess.set_providers(['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[1].name

print(sess.get_outputs()[0].shape)
output = sess.run([output_name], {input_name : imgs})
print(output[0].shape)

res = output[0][0][0].transpose(1, 2, 0)
cv2.imshow("res", res)
cv2.imwrite("results/anims/anim03.jpg", (res * 255).astype(np.uint8))

cv2.waitKey(0)

遇到的问题

SRGAN model does not support forward_train function.

这个按照我的理解是test_mode字段没有传下去,于是直接把下面的代码改了下,原本test_mode默认为False,我改成了True,不确定是否有其他影响,只能说目前可用

@auto_fp16(apply_to=('lq', ))
def forward(self, lq, gt=None, test_mode=True, **kwargs):
    """Forward function.

    Args:
        lq (Tensor): Input lq images.
        gt (Tensor): Ground-truth image. Default: None.
        test_mode (bool): Whether in test mode or not. Default: False.
        kwargs (dict): Other arguments.
    """
    if test_mode:
        return self.forward_test(lq, gt, **kwargs)

    raise ValueError(
        'SRGAN model does not support `forward_train` function.')

output出来的数据shape和输入的一样

这个找了很久,最后发现模型其实有2个输出,应该导出模型的时候都命名,然后选择第二个就好了。上面的使用代码中直接选择了第二个output

output_name = sess.get_outputs()[1].name