onnx设置动态batch/修改onnx的batch

3,058 阅读1分钟

动态batch设置方法,列举两种

  • 训练框架(例如 PyTorch)导出 ONNX 模型时设置动态 batch
def export_onnx(model, input_hwc, output_file, input_names=['input'],
            output_names=['output'], show=False, opset_version=9, dynamic=True):
    """Export a PyTorch model to ONNX, optionally with a dynamic batch axis.

    Args:
        model: Module to export. It is moved to CPU and switched to eval
            mode before tracing.
        input_hwc: ``(height, width, channels)`` of one input image. The
            dummy tracing input is built as ``[1, channels, height, width]``.
        output_file: Destination handed to ``torch.onnx.export``.
        input_names: Names assigned to the graph inputs.
        output_names: Names assigned to the graph outputs.
        show: Forwarded to ``torch.onnx.export`` as ``verbose``.
        opset_version: ONNX opset version to target.
        dynamic: When true, axis 0 of every named input and output is
            exported as the symbolic dimension ``'batch'``; otherwise all
            axes keep the static sizes of the dummy input.
    """
    if dynamic:
        # Key the mapping by the names actually passed in. Hard-coding
        # 'input'/'output' would silently leave the batch axis static as soon
        # as a caller supplies different names.
        dynamic_axes = {
            name: {0: 'batch'}
            for name in (*(input_names or ()), *(output_names or ()))
        }
    else:
        dynamic_axes = {}

    h, w, c = input_hwc
    # Batch size 1 is only a tracing placeholder when the axis is dynamic.
    input_shape = [1, c, h, w]
    one_img = torch.randn(input_shape)

    # register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model.cpu().eval(),
            one_img,
            output_file,
            input_names=input_names,
            output_names=output_names,
            export_params=True,
            # Weights are also listed as graph inputs in the exported model.
            keep_initializers_as_inputs=True,
            verbose=show,
            dynamic_axes=dynamic_axes,
            opset_version=opset_version)
    print('>>> finish export onnx:', output_file)

  • 通过onnx库修改onnx模型的batch
# 安装onnx:pip install onnx
import onnx
def change_input_dim(model):
    """Make the first dimension of every real graph input a symbolic batch dim.

    The model is modified in place; nothing is returned. Graph inputs that
    are backed by an initializer (weights listed as inputs) and inputs with
    no shape dimensions are left untouched. All remaining inputs receive the
    same symbolic name, so they are assumed to share one batch dimension.

    Args:
        model: Loaded ONNX model whose ``graph.input`` entries are updated.
    """
    # Use some symbolic name not used for any other dimension.
    sym_batch_dim = "N"

    graph = model.graph
    # Weights may appear in graph.input as well; rewriting their first
    # dimension would corrupt the declared parameter shapes.
    initializer_names = {init.name for init in graph.initializer}

    for graph_input in graph.input:
        if graph_input.name in initializer_names:
            continue
        dims = graph_input.type.tensor_type.shape.dim
        if not dims:
            # Nothing to treat as a batch axis.
            continue
        # Update the first dimension to be a symbolic value.
        dims[0].dim_param = sym_batch_dim
        # To pin a fixed batch size instead, assign an integer:
        #   dims[0].dim_value = 4


def apply(transform, infile, outfile):
    """Load an ONNX model, run a transform on it, and save the result.

    Args:
        transform: Callable invoked with the loaded model. Its return value
            is ignored, so it has to modify the model in place.
        infile: Location the model is loaded from.
        outfile: Location the transformed model is saved to.
    """
    loaded = onnx.load(infile)
    # The same object is saved afterwards, so only in-place edits persist.
    transform(loaded)
    onnx.save(loaded, outfile)

# NOTE: onnx_pth (model to read) and save_pth (where the modified model is
# written) are not defined in this snippet; assign both before running it.
apply(change_input_dim, onnx_pth, save_pth)

onnxruntime 动态batch 推理

# 安装onnxruntime: pip install onnxruntime
import onnxruntime as ort
import numpy as np
import torch 
# One random sample shaped (1, 3, 112, 112); the leading 1 is the batch size.
x1 = torch.rand(1, 3, 112, 112)
# NOTE: path_to_onnx_model is not defined in this snippet; set it to the
# exported .onnx file before running.
ort_sess1 = ort.InferenceSession(path_to_onnx_model)
# Read the input name from the session instead of hard-coding it: the name
# depends on how the model was exported (e.g. 'input' vs. 'input.1').
input_name = ort_sess1.get_inputs()[0].name
# Feed x1 -- the tensor created above (the original referenced an undefined x2).
outputs1 = ort_sess1.run(None, {input_name: x1.numpy()})