使用ONNX Runtime在Python中进行模型推理

1,209 阅读3分钟

ONNX Runtime(ORT)是一种高性能的推理引擎,支持多种深度学习框架,如PyTorch、TensorFlow和scikit-learn。以下是如何在Python中安装和使用ONNX Runtime的简要指南。

安装ONNX Runtime

ONNX Runtime有两个主要的Python包:CPU版本和GPU版本。根据你的硬件选择合适的版本。

  • CPU版本:适用于Arm-based CPU和macOS。

    pip install onnxruntime
    
  • GPU版本(CUDA 12.x):默认支持CUDA 12.x。

    pip install onnxruntime-gpu
    
  • GPU版本(CUDA 11.8):需要从Azure DevOps Feed安装。

    pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/
    

安装ONNX导出工具

PyTorch

PyTorch自带ONNX支持,直接安装PyTorch即可。

pip install torch

TensorFlow

需要额外安装tf2onnx来支持ONNX导出。

pip install tf2onnx

scikit-learn

需要安装skl2onnx来支持ONNX导出。

pip install skl2onnx

快速开始示例

PyTorch CV示例

  1. 导出模型

    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch模型,device是设备(如GPU或CPU)
    model = ...  # 初始化你的模型
    device = ...  # 设备
    
    input_data = torch.randn(1, 28, 28).to(device)
    onnx.export(model, input_data, "fashion_mnist_model.onnx", 
                input_names=['input'], output_names=['output'])
    
  2. 加载ONNX模型并进行推理

    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("fashion_mnist_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('fashion_mnist_model.onnx')
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    outputs = ort_sess.run(None, {'input': x.numpy()})
    

PyTorch NLP示例

  1. 导出模型

    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch NLP模型
    model = ...  # 初始化你的模型
    
    text = "示例文本"
    text_tensor = torch.tensor(text_pipeline(text))  # 假设text_pipeline是文本预处理函数
    offsets = torch.tensor([0])
    
    onnx.export(model, (text_tensor, offsets), "ag_news_model.onnx", 
                input_names=['input', 'offsets'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    
  2. 加载ONNX模型并进行推理

    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("ag_news_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('ag_news_model.onnx')
    
    # 假设text是输入文本,offsets是偏移量
    text = ...  # 初始化输入文本
    offsets = torch.tensor([0])
    
    outputs = ort_sess.run(None, {'input': text.numpy(), 'offsets': offsets.numpy()})
    

TensorFlow CV示例

  1. 导出模型

    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    import tf2onnx
    
    # 加载预训练模型
    model = ResNet50(weights='imagenet')
    
    # 转换为ONNX模型
    spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
    output_path = model.name + ".onnx"
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
    
  2. 加载ONNX模型并进行推理

    import onnxruntime as rt
    
    # 创建推理会话
    providers = ['CPUExecutionProvider']
    m = rt.InferenceSession(output_path, providers=providers)
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    output_names = [n.name for n in model_proto.graph.output]
    onnx_pred = m.run(output_names, {"input": x})
    

scikit-learn CV示例

  1. 导出模型

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import convert_sklearn
    
    # 加载iris数据集
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    # 训练逻辑回归模型
    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    
    # 转换为ONNX模型
    initial_type = [('float_input', FloatTensorType([None, 4]))]
    onx = convert_sklearn(clr, initial_types=initial_type)
    with open("logreg_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())
    
  2. 加载ONNX模型并进行推理

    import numpy
    import onnxruntime as rt
    
    # 创建推理会话
    sess = rt.InferenceSession("logreg_iris.onnx")
    input_name = sess.get_inputs()[0].name
    
    # 假设X_test是测试数据
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]
    

通过这些示例,你可以轻松地将不同框架的模型导出为ONNX格式,并使用ONNX Runtime进行高效的推理。