动态 batch 的设置方法,这里列举两种:一是在导出 ONNX 模型时设置,二是直接修改已有 ONNX 模型的输入维度。
- 训练框架(例如 PyTorch)导出 ONNX 模型时设置动态 batch
def export_onnx(model, input_hwc, output_file, input_names=['input'],
output_names=['output'], show=False, opset_version=9, dynamic=True):
if dynamic:
dynamic_axes = {
'input': {
0: 'batch',
},
'output': {
0: 'batch'
}
}
else:
dynamic_axes = {}
h, w, c = input_hwc
input_shape = [1, c, h, w]
one_img = torch.randn(input_shape)
with torch.no_grad():
torch.onnx.export(
model.cpu().eval(),
one_img,
output_file,
input_names=input_names,
output_names=output_names,
export_params=True,
keep_initializers_as_inputs=True,
verbose=show,
dynamic_axes=dynamic_axes,
opset_version=opset_version)
print('>>> finish export onnx:', output_file)
import onnx
def change_input_dim(model, batch_dim="N"):
    """Rewrite the leading (batch) dimension of the model's data inputs in place.

    Args:
        model: Object exposing ``graph.input`` and ``graph.initializer`` in
            the layout of an ONNX ``ModelProto``.
        batch_dim: A string sets a symbolic dimension name (``dim_param``);
            any other value is converted to ``int`` and stored as a fixed
            size (``dim_value``). Defaults to the symbolic name ``"N"``.
    """
    # graph.input may also list weights (e.g. models exported with
    # keep_initializers_as_inputs=True). Their first axis is not a batch
    # axis, so they must be left alone.
    initializer_names = {init.name for init in model.graph.initializer}

    for graph_input in model.graph.input:
        if graph_input.name in initializer_names:
            continue
        dims = graph_input.type.tensor_type.shape.dim
        if len(dims) == 0:
            # Rank-0 input: there is no leading dimension to rewrite.
            continue
        if isinstance(batch_dim, str):
            dims[0].dim_param = batch_dim
        else:
            dims[0].dim_value = int(batch_dim)
def apply(transform, infile, outfile):
    """Load an ONNX model, run ``transform`` on it, and save the result.

    Args:
        transform: Callable invoked with the loaded model; its return value
            is ignored, so it has to modify the model in place.
        infile: Path passed to ``onnx.load``.
        outfile: Path passed to ``onnx.save``.
    """
    loaded_model = onnx.load(infile)
    # The same object that was loaded is the one written back out.
    transform(loaded_model)
    onnx.save(loaded_model, outfile)
# Load the model at onnx_pth, rewrite its input batch dimension with
# change_input_dim, and save the result to save_pth.
# NOTE(review): onnx_pth and save_pth are not defined in the visible source;
# they must be assigned before this call runs.
apply(change_input_dim, onnx_pth, save_pth)
使用 onnxruntime 进行动态 batch 推理:
import onnxruntime as ort
import numpy as np
import torch
# Run the ONNX model once with onnxruntime on a random batch-of-1 input.
# NOTE(review): path_to_onnx_model is not defined in the visible source; it
# must be assigned before this snippet runs.
x1 = torch.rand(1, 3, 112, 112)
ort_sess1 = ort.InferenceSession(path_to_onnx_model)
# Read the input name from the session instead of hard-coding 'input.1', so
# the feed matches whatever name the model was exported with.
input_name1 = ort_sess1.get_inputs()[0].name
# Feed x1: the original referenced x2, which is never defined.
outputs1 = ort_sess1.run(None, {input_name1: x1.numpy()})