torch转ONNX设置动态batch_size

1,442 阅读1分钟

需求

ONNX往往会设定单帧图像推理,那如果我们希望使用batch size去推理呢?

解决

首先从.pth转ONNX文件之前,你需要检查每个网络模块的forward,查看forward中是否存在reshapeview等调整输出feature map形状的操作。确保如果batch size不为1时,我们能否正确输出第一维是batch size的。 例如RepVGG的mmClassification版本中,GlobalAveragePooling的forward中可以做出绿色高亮的修改,以让Pooling后的feature map第二维保持不变,第一维跟随batch size发生变化:

image.png

随后先将.pth加载并通过torch.save(model, save_path)保存一个新的.pth文件以使之生效。

其次通过如下方法导出ONNX,记得标注dynamic:

model = torch.load(args.ckpt_path)
model.eval()
model.forward = model.simple_test
dummy_input = torch.ones([1, 3, 224, 224])
torch.onnx.export(
    model, dummy_input, args.output_onnx_path, export_params=True, opset_version=13, input_names=["input"], output_names=["output"], 
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

然后使用Netron可视化一下,看看有没有因为dynamic_axes导致的额外reshape等分支,最好没有额外分支。

最后,在

from onnxsim import simplify
onnx_model, check = simplify(onnx_model, dynamic_input_shape=True, input_shapes={"input": [32, 3, 224, 224]})

中务必加入dynamic_input_shape=True, input_shapes={"input": [batch_size, 3, 224, 224]}两个参数。

既可以按照batch用ONNX做推理。