PyTorch模型转换为TensorRT模型实操

41 阅读9分钟

你想知道如何将PyTorch模型转换为TensorRT模型,核心是无法直接转换,必须通过ONNX中间格式作为桥梁,整个流程清晰可落地,下面我将以实操为主,分步讲解完整转换过程、关键注意事项和优化技巧,所有代码均可直接复刻。

一、核心转换逻辑与前置准备

1. 核心转换流程

PyTorch模型 → 导出ONNX中间模型(关键载体,解决跨框架兼容性) → 优化ONNX模型 → 转换为TensorRT引擎(.trt文件,最终部署文件)

2. 前置环境安装

需安装对应版本的依赖,确保兼容性(推荐Python 3.6-3.9,适配IoT设备如Jetson系列):

# 1. 安装PyTorch(根据设备架构选择,Jetson设备推荐官网预编译包)
pip3 install torch torchvision torchaudio

# 2. 安装ONNX相关工具
pip3 install onnx==1.12.0 onnx-simplifier==0.4.33

# 3. 安装TensorRT(已完成环境搭建可跳过,参考之前IoT教程)
# 补充:验证TensorRT是否可用
python3 -c "import tensorrt as trt; print(f'TensorRT版本:{trt.__version__}')"

二、步骤1:PyTorch模型导出为ONNX格式

这是转换的基础,导出的ONNX模型兼容性直接决定后续TensorRT转换成败,以下以轻量化模型ResNet18为例(适配IoT设备算力,也可替换为MobileNetV3、YOLOv8-Nano等)。

完整导出代码

# 文件名:pytorch2onnx.py
import torch
import torchvision.models as models

def export_pytorch_to_onnx():
    # 1. 加载PyTorch模型并设置为评估模式(必须!避免BatchNorm、Dropout等训练层影响)
    model = models.resnet18(pretrained=True).eval()
    # 适配GPU(Jetson设备自带CUDA,直接启用;无GPU可注释该行)
    if torch.cuda.is_available():
        model = model.cuda()

    # 2. 构造虚拟输入张量(dummy input)
    # 关键:输入尺寸必须与实际部署时的输入尺寸一致(如3×224×224)
    # batch_size先固定为1,后续可配置动态批量
    batch_size = 1
    input_channels = 3
    input_height = 224
    input_width = 224
    dummy_input = torch.randn(batch_size, input_channels, input_height, input_width)
    if torch.cuda.is_available():
        dummy_input = dummy_input.cuda()

    # 3. 导出ONNX模型(关键参数适配TensorRT)
    onnx_save_path = "resnet18_iot.onnx"
    torch.onnx.export(
        model,  # 待导出的PyTorch模型
        dummy_input,  # 虚拟输入张量
        onnx_save_path,  # ONNX模型保存路径
        input_names=["input_image"],  # 输入节点名称(后续TensorRT推理需对应)
        output_names=["class_output"],  # 输出节点名称(后续TensorRT推理需对应)
        opset_version=12,  # 适配TensorRT的opset版本(推荐11-13,过高易出现算子不支持)
        do_constant_folding=True,  # 折叠常量节点,优化ONNX模型体积和速度
        dynamic_axes={  # 可选:配置动态批量,支持推理时修改batch_size
            "input_image": {0: "batch_size"},
            "class_output": {0: "batch_size"}
        }
    )

    print(f"ONNX模型导出成功!保存路径:{onnx_save_path}")
    print(f"模型输入尺寸:({batch_size}, {input_channels}, {input_height}, {input_width})")

if __name__ == "__main__":
    export_pytorch_to_onnx()

运行结果与关键注意事项

  1. 运行命令:python3 pytorch2onnx.py,成功生成resnet18_iot.onnx文件。
  2. 避坑要点:
    • 模型必须调用.eval():训练模式下的层(如BatchNorm)会导致ONNX模型存在动态节点,TensorRT无法解析。
    • 虚拟输入尺寸与实际部署一致:若后续部署输入为3×192×192,此处需同步修改,避免推理时尺寸不匹配。
    • opset_version不可过高:TensorRT对高版本opset支持滞后,优先选择11-13。

三、步骤2:优化ONNX模型(可选但推荐)

导出的ONNX模型可能包含冗余节点、未折叠的常量运算,通过onnx-simplifier简化后,可减少TensorRT转换时间和最终引擎体积。

优化命令(一行执行)

# 格式:python3 -m onnxsim 原ONNX文件 优化后ONNX文件
python3 -m onnxsim resnet18_iot.onnx resnet18_iot_simplified.onnx

优化效果

  • 模型体积缩小30%-50%,去除冗余计算节点。
  • 减少TensorRT解析时的报错概率,提升转换效率。
  • 优化后的模型可直接用于后续转换,功能与原模型一致。

四、步骤3:ONNX模型转换为TensorRT引擎(核心步骤)

有两种常用方式,方式1(Python API)灵活可控,支持量化优化(适配IoT设备);方式2(trtexec命令行)简单快捷,适合快速验证

方式1:Python API转换(推荐,支持INT8/FP16量化)

该方式可自定义量化策略,最大化IoT设备推理性能,完整代码如下(包含INT8量化校准,适配算力有限设备)。

1. 完整转换代码

# 文件名:onnx2tensorrt.py
import tensorrt as trt
import os
import cv2
import numpy as np

# 配置全局参数
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)  # 日志级别:减少IoT设备输出冗余
ONNX_FILE = "resnet18_iot_simplified.onnx"
TRT_ENGINE_FILE = "resnet18_iot.trt"
MAX_WORKSPACE_SIZE = 1 << 30  # 最大工作空间:1GB(Jetson Nano 4GB适配,2GB设备改为1<<29)

# 自定义INT8校准器(适配IoT设备,减少内存占用,提升推理速度)
class IoTInt8Calibrator(trt.IInt8Calibrator):
    def __init__(self, calibration_data_path, cache_file="calibration.cache", batch_size=8):
        trt.IInt8Calibrator.__init__(self)
        self.cache_file = cache_file  # 校准缓存文件,避免重复校准
        self.batch_size = batch_size  # 校准批量,根据IoT设备内存调整
        self.calibration_data = self._load_and_preprocess_calibration_data(calibration_data_path)
        self.batch_index = 0

        # 分配设备输入内存(减少内存拷贝开销)
        self.input_shape = (self.batch_size, 3, 224, 224)
        self.device_input = trt.cuda.DeviceMemory(np.prod(self.input_shape) * 4)  # 4字节/浮点数

    def _load_and_preprocess_calibration_data(self, data_path):
        """加载校准集并预处理(与推理时步骤完全一致,关键!)"""
        # 校准集要求:100-500张图片,与训练数据分布一致,覆盖所有类别
        calibration_files = [os.path.join(data_path, f) for f in os.listdir(data_path) 
                             if f.endswith((".jpg", ".png"))][:500]  # 限制数量,避免IoT内存溢出

        calibration_data = []
        for file in calibration_files:
            img = cv2.imread(file)
            if img is None:
                continue
            # 预处理:与PyTorch训练/后续推理一致(缩放、转置、归一化)
            img = cv2.resize(img, (224, 224))
            img = img.transpose((2, 0, 1))  # HWC → CHW(PyTorch输入格式)
            img = img / 255.0  # 归一化(与训练时保持一致)
            calibration_data.append(img.astype(np.float32))

        return np.array(calibration_data, dtype=np.float32)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        """获取批量校准数据,返回None表示校准结束"""
        if self.batch_index + self.batch_size > len(self.calibration_data):
            return None
        batch_data = self.calibration_data[self.batch_index:self.batch_index+self.batch_size]
        self.batch_index += self.batch_size

        # 复制数据到设备内存
        np.copyto(self.device_input, batch_data.ravel())
        return [int(self.device_input)]

    def read_calibration_cache(self):
        """读取校准缓存,避免重复校准"""
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """保存校准缓存,供后续复用"""
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# 构建TensorRT引擎
def build_trt_engine():
    # 检查ONNX文件是否存在
    if not os.path.exists(ONNX_FILE):
        raise FileNotFoundError(f"ONNX文件 {ONNX_FILE} 不存在,请先导出并优化ONNX模型")

    # 1. 创建TensorRT构建器、网络、解析器
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # 2. 解析ONNX模型
    with open(ONNX_FILE, "rb") as f:
        if not parser.parse(f.read()):
            print("ONNX模型解析失败!以下是错误详情:")
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # 3. 配置构建参数(适配IoT设备,核心优化)
    config = builder.create_builder_config()
    config.max_workspace_size = MAX_WORKSPACE_SIZE

    # 4. 启用量化优化(二选一,优先INT8量化,极致适配IoT)
    # 选项A:启用INT8量化(需校准集,速度提升3-10倍,精度损失≤3%)
    if builder.platform_has_fast_int8:
        config.set_flag(trt.BuilderFlag.INT8)
        # 传入校准器(需提前准备校准集文件夹./calibration_data)
        calibrator = IoTInt8Calibrator(calibration_data_path="./calibration_data")
        config.int8_calibrator = calibrator
        print("已启用INT8量化,开始加载校准集...")
    # 选项B:启用FP16量化(无需校准集,无明显精度损失,速度提升2-5倍)
    elif builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("已启用FP16量化...")

    # 5. 构建并保存TensorRT引擎
    print("开始构建TensorRT引擎(IoT设备可能耗时5-10分钟,请耐心等待)...")
    serialized_engine = builder.build_serialized_network(network, config)
    if not serialized_engine:
        print("TensorRT引擎构建失败!")
        return None

    with open(TRT_ENGINE_FILE, "wb") as f:
        f.write(serialized_engine)
    print(f"TensorRT引擎构建成功!保存路径:{TRT_ENGINE_FILE}")
    return serialized_engine

if __name__ == "__main__":
    build_trt_engine()

2. 前置准备与运行

  1. 准备校准集:创建./calibration_data文件夹,放入100-500张与训练数据分布一致的图片(分类任务覆盖所有类别,检测任务包含各类目标)。
  2. 运行命令:python3 onnx2tensorrt.py,成功生成resnet18_iot.trt引擎文件。
  3. 关键说明:
    • INT8量化是IoT设备的最优选择,在精度损失可控的前提下,最大化推理速度、降低内存占用。
    • 生成的.trt引擎与设备架构绑定(如Jetson Nano生成的引擎无法在RK3588上运行),需在目标IoT设备上完成转换。

方式2:trtexec命令行转换(简单快捷,快速验证)

trtexec是TensorRT自带的命令行工具,适合快速验证ONNX模型是否可转换,无需编写复杂代码。

1. 基本命令(FP16量化示例)

# 格式:trtexec --onnx=优化后ONNX文件 --saveEngine=输出TRT引擎文件 --fp16
trtexec --onnx=resnet18_iot_simplified.onnx --saveEngine=resnet18_iot_trtexec.trt --fp16 --workspace=1024

2. INT8量化命令(需校准集)

trtexec --onnx=resnet18_iot_simplified.onnx --saveEngine=resnet18_iot_int8.trt --int8 --calib=./calibration_data --workspace=1024

3. 参数说明

  • --workspace:工作空间大小,单位MB(Jetson Nano推荐1024,即1GB)。
  • --fp16/--int8:启用量化模式。
  • --verbose:输出详细日志,用于排查转换错误。

五、步骤4:验证TensorRT引擎有效性

转换完成后,通过简单的推理代码验证引擎是否可用,同时查看推理耗时(适配IoT设备)。

验证代码

# 文件名:verify_trt_engine.py
import tensorrt as trt
import cv2
import numpy as np
import time

class TRTInferencer:
    def __init__(self, engine_file):
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self._load_trt_engine(engine_file)
        self.context = self.engine.create_execution_context()
        self._init_input_output()

    def _load_trt_engine(self, engine_file):
        """加载TensorRT引擎"""
        if not os.path.exists(engine_file):
            raise FileNotFoundError(f"TensorRT引擎文件 {engine_file} 不存在")
        with open(engine_file, "rb") as f:
            serialized_engine = f.read()
        return self.runtime.deserialize_cuda_engine(serialized_engine)

    def _init_input_output(self):
        """初始化输入输出节点信息"""
        self.input_name = "input_image"
        self.output_name = "class_output"
        self.input_idx = self.engine.get_binding_index(self.input_name)
        self.output_idx = self.engine.get_binding_index(self.output_name)
        self.input_shape = self.engine.get_binding_shape(self.input_idx)
        self.output_shape = self.engine.get_binding_shape(self.output_idx)

    def preprocess(self, img_path):
        """预处理(与训练/校准一致)"""
        img = cv2.imread(img_path)
        img = cv2.resize(img, (224, 224))
        img = img.transpose((2, 0, 1)) / 255.0
        img = np.expand_dims(img, axis=0).astype(np.float32)
        return np.ascontiguousarray(img)

    def infer(self, img_path):
        """执行推理并返回结果"""
        input_data = self.preprocess(img_path)

        # 分配内存
        input_host = input_data.ravel()
        output_host = np.empty(np.prod(self.output_shape), dtype=np.float32)
        input_device = trt.cuda.DeviceMemory(len(input_host) * 4)
        output_device = trt.cuda.DeviceMemory(len(output_host) * 4)

        # 异步推理
        stream = trt.cuda.Stream()
        input_device.copy_from_host(input_host, stream)

        # 计时推理
        start_time = time.time()
        self.context.execute_async_v2(
            bindings=[int(input_device), int(output_device)],
            stream_handle=stream.handle
        )
        stream.synchronize()
        infer_time = (time.time() - start_time) * 1000  # 转换为毫秒

        # 拷贝结果
        output_device.copy_to_host(output_host, stream)
        output_data = output_host.reshape(self.output_shape)
        top1_class = np.argmax(output_data[0])

        return {
            "top1_class": top1_class,
            "infer_time_ms": round(infer_time, 2)
        }

if __name__ == "__main__":
    inferencer = TRTInferencer("resnet18_iot.trt")
    result = inferencer.infer("test.jpg")
    print(f"Top-1类别:{result['top1_class']}")
    print(f"单帧推理耗时:{result['infer_time_ms']}ms")

预期结果(Jetson Nano 4GB)

  • INT8量化:推理耗时≤20ms,内存占用≤300MB。
  • FP16量化:推理耗时≤30ms,内存占用≤400MB。

六、常见转换问题排查

  1. 报错:"ONNX node xxx is not supported by TensorRT"

    • 原因:ONNX模型包含TensorRT不支持的算子,或opset版本过高。
    • 解决方案:降低opset版本(如从14降至12)、用onnx-simplifier重新优化、替换不支持的算子(如用普通Conv2d替换DepthwiseConv2d)。
  2. 报错:"libnvinfer.so.8: cannot open shared object file"

    • 原因:TensorRT环境变量未配置生效。
    • 解决方案:执行export LD_LIBRARY_PATH=/usr/local/TensorRT/lib:$LD_LIBRARY_PATH,或永久配置环境变量。
  3. INT8量化精度损失严重(>5%)

    • 原因:校准集数量不足、分布不一致,或预处理步骤与训练不符。
    • 解决方案:增加校准集至200张以上、确保校准集覆盖所有类别、严格对齐预处理步骤(归一化、缩放、转置顺序)。
  4. 引擎构建超时(IoT设备)

    • 原因:工作空间设置过大、输入尺寸过大、批量校准值过高。
    • 解决方案:减小工作空间(如512MB)、缩小输入尺寸(如224→192)、降低校准批量(如8→4)。