你想知道如何将PyTorch模型转换为TensorRT模型,核心是无法直接转换,必须通过ONNX中间格式作为桥梁,整个流程清晰可落地,下面我将以实操为主,分步讲解完整转换过程、关键注意事项和优化技巧,所有代码均可直接复刻。
一、核心转换逻辑与前置准备
1. 核心转换流程
PyTorch模型 → 导出ONNX中间模型(关键载体,解决跨框架兼容性) → 优化ONNX模型 → 转换为TensorRT引擎(.trt文件,最终部署文件)
2. 前置环境安装
需安装对应版本的依赖,确保兼容性(推荐Python 3.6-3.9,适配IoT设备如Jetson系列):
# 1. 安装PyTorch(根据设备架构选择,Jetson设备推荐官网预编译包)
pip3 install torch torchvision torchaudio
# 2. 安装ONNX相关工具
pip3 install onnx==1.12.0 onnx-simplifier==0.4.33
# 3. 安装TensorRT(已完成环境搭建可跳过,参考之前IoT教程)
# 补充:验证TensorRT是否可用
python3 -c "import tensorrt as trt; print(f'TensorRT版本:{trt.__version__}')"
二、步骤1:PyTorch模型导出为ONNX格式
这是转换的基础,导出的ONNX模型兼容性直接决定后续TensorRT转换成败,以下以轻量化模型ResNet18为例(适配IoT设备算力,也可替换为MobileNetV3、YOLOv8-Nano等)。
完整导出代码
# 文件名:pytorch2onnx.py
import torch
import torchvision.models as models
def export_pytorch_to_onnx():
# 1. 加载PyTorch模型并设置为评估模式(必须!避免BatchNorm、Dropout等训练层影响)
model = models.resnet18(pretrained=True).eval()
# 适配GPU(Jetson设备自带CUDA,直接启用;无GPU可注释该行)
if torch.cuda.is_available():
model = model.cuda()
# 2. 构造虚拟输入张量(dummy input)
# 关键:输入尺寸必须与实际部署时的输入尺寸一致(如3×224×224)
# batch_size先固定为1,后续可配置动态批量
batch_size = 1
input_channels = 3
input_height = 224
input_width = 224
dummy_input = torch.randn(batch_size, input_channels, input_height, input_width)
if torch.cuda.is_available():
dummy_input = dummy_input.cuda()
# 3. 导出ONNX模型(关键参数适配TensorRT)
onnx_save_path = "resnet18_iot.onnx"
torch.onnx.export(
model, # 待导出的PyTorch模型
dummy_input, # 虚拟输入张量
onnx_save_path, # ONNX模型保存路径
input_names=["input_image"], # 输入节点名称(后续TensorRT推理需对应)
output_names=["class_output"], # 输出节点名称(后续TensorRT推理需对应)
opset_version=12, # 适配TensorRT的opset版本(推荐11-13,过高易出现算子不支持)
do_constant_folding=True, # 折叠常量节点,优化ONNX模型体积和速度
dynamic_axes={ # 可选:配置动态批量,支持推理时修改batch_size
"input_image": {0: "batch_size"},
"class_output": {0: "batch_size"}
}
)
print(f"ONNX模型导出成功!保存路径:{onnx_save_path}")
print(f"模型输入尺寸:({batch_size}, {input_channels}, {input_height}, {input_width})")
if __name__ == "__main__":
export_pytorch_to_onnx()
运行结果与关键注意事项
- 运行命令:
python3 pytorch2onnx.py,成功生成resnet18_iot.onnx文件。 - 避坑要点:
- 模型必须调用
.eval():训练模式下的层(如BatchNorm)会导致ONNX模型存在动态节点,TensorRT无法解析。 - 虚拟输入尺寸与实际部署一致:若后续部署输入为3×192×192,此处需同步修改,避免推理时尺寸不匹配。
- opset_version不可过高:TensorRT对高版本opset支持滞后,优先选择11-13。
- 模型必须调用
三、步骤2:优化ONNX模型(可选但推荐)
导出的ONNX模型可能包含冗余节点、未折叠的常量运算,通过onnx-simplifier简化后,可减少TensorRT转换时间和最终引擎体积。
优化命令(一行执行)
# 格式:python3 -m onnxsim 原ONNX文件 优化后ONNX文件
python3 -m onnxsim resnet18_iot.onnx resnet18_iot_simplified.onnx
优化效果
- 模型体积缩小30%-50%,去除冗余计算节点。
- 减少TensorRT解析时的报错概率,提升转换效率。
- 优化后的模型可直接用于后续转换,功能与原模型一致。
四、步骤3:ONNX模型转换为TensorRT引擎(核心步骤)
有两种常用方式,方式1(Python API)灵活可控,支持量化优化(适配IoT设备);方式2(trtexec命令行)简单快捷,适合快速验证。
方式1:Python API转换(推荐,支持INT8/FP16量化)
该方式可自定义量化策略,最大化IoT设备推理性能,完整代码如下(包含INT8量化校准,适配算力有限设备)。
1. 完整转换代码
# 文件名:onnx2tensorrt.py
import tensorrt as trt
import os
import cv2
import numpy as np
# 配置全局参数
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) # 日志级别:减少IoT设备输出冗余
ONNX_FILE = "resnet18_iot_simplified.onnx"
TRT_ENGINE_FILE = "resnet18_iot.trt"
MAX_WORKSPACE_SIZE = 1 << 30 # 最大工作空间:1GB(Jetson Nano 4GB适配,2GB设备改为1<<29)
# 自定义INT8校准器(适配IoT设备,减少内存占用,提升推理速度)
class IoTInt8Calibrator(trt.IInt8Calibrator):
def __init__(self, calibration_data_path, cache_file="calibration.cache", batch_size=8):
trt.IInt8Calibrator.__init__(self)
self.cache_file = cache_file # 校准缓存文件,避免重复校准
self.batch_size = batch_size # 校准批量,根据IoT设备内存调整
self.calibration_data = self._load_and_preprocess_calibration_data(calibration_data_path)
self.batch_index = 0
# 分配设备输入内存(减少内存拷贝开销)
self.input_shape = (self.batch_size, 3, 224, 224)
self.device_input = trt.cuda.DeviceMemory(np.prod(self.input_shape) * 4) # 4字节/浮点数
def _load_and_preprocess_calibration_data(self, data_path):
"""加载校准集并预处理(与推理时步骤完全一致,关键!)"""
# 校准集要求:100-500张图片,与训练数据分布一致,覆盖所有类别
calibration_files = [os.path.join(data_path, f) for f in os.listdir(data_path)
if f.endswith((".jpg", ".png"))][:500] # 限制数量,避免IoT内存溢出
calibration_data = []
for file in calibration_files:
img = cv2.imread(file)
if img is None:
continue
# 预处理:与PyTorch训练/后续推理一致(缩放、转置、归一化)
img = cv2.resize(img, (224, 224))
img = img.transpose((2, 0, 1)) # HWC → CHW(PyTorch输入格式)
img = img / 255.0 # 归一化(与训练时保持一致)
calibration_data.append(img.astype(np.float32))
return np.array(calibration_data, dtype=np.float32)
def get_batch_size(self):
return self.batch_size
def get_batch(self, names):
"""获取批量校准数据,返回None表示校准结束"""
if self.batch_index + self.batch_size > len(self.calibration_data):
return None
batch_data = self.calibration_data[self.batch_index:self.batch_index+self.batch_size]
self.batch_index += self.batch_size
# 复制数据到设备内存
np.copyto(self.device_input, batch_data.ravel())
return [int(self.device_input)]
def read_calibration_cache(self):
"""读取校准缓存,避免重复校准"""
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
return f.read()
return None
def write_calibration_cache(self, cache):
"""保存校准缓存,供后续复用"""
with open(self.cache_file, "wb") as f:
f.write(cache)
# 构建TensorRT引擎
def build_trt_engine():
# 检查ONNX文件是否存在
if not os.path.exists(ONNX_FILE):
raise FileNotFoundError(f"ONNX文件 {ONNX_FILE} 不存在,请先导出并优化ONNX模型")
# 1. 创建TensorRT构建器、网络、解析器
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
# 2. 解析ONNX模型
with open(ONNX_FILE, "rb") as f:
if not parser.parse(f.read()):
print("ONNX模型解析失败!以下是错误详情:")
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
# 3. 配置构建参数(适配IoT设备,核心优化)
config = builder.create_builder_config()
config.max_workspace_size = MAX_WORKSPACE_SIZE
# 4. 启用量化优化(二选一,优先INT8量化,极致适配IoT)
# 选项A:启用INT8量化(需校准集,速度提升3-10倍,精度损失≤3%)
if builder.platform_has_fast_int8:
config.set_flag(trt.BuilderFlag.INT8)
# 传入校准器(需提前准备校准集文件夹./calibration_data)
calibrator = IoTInt8Calibrator(calibration_data_path="./calibration_data")
config.int8_calibrator = calibrator
print("已启用INT8量化,开始加载校准集...")
# 选项B:启用FP16量化(无需校准集,无明显精度损失,速度提升2-5倍)
elif builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print("已启用FP16量化...")
# 5. 构建并保存TensorRT引擎
print("开始构建TensorRT引擎(IoT设备可能耗时5-10分钟,请耐心等待)...")
serialized_engine = builder.build_serialized_network(network, config)
if not serialized_engine:
print("TensorRT引擎构建失败!")
return None
with open(TRT_ENGINE_FILE, "wb") as f:
f.write(serialized_engine)
print(f"TensorRT引擎构建成功!保存路径:{TRT_ENGINE_FILE}")
return serialized_engine
if __name__ == "__main__":
build_trt_engine()
2. 前置准备与运行
- 准备校准集:创建
./calibration_data文件夹,放入100-500张与训练数据分布一致的图片(分类任务覆盖所有类别,检测任务包含各类目标)。 - 运行命令:
python3 onnx2tensorrt.py,成功生成resnet18_iot.trt引擎文件。 - 关键说明:
- INT8量化是IoT设备的最优选择,在精度损失可控的前提下,最大化推理速度、降低内存占用。
- 生成的
.trt引擎与设备架构绑定(如Jetson Nano生成的引擎无法在RK3588上运行),需在目标IoT设备上完成转换。
方式2:trtexec命令行转换(简单快捷,快速验证)
trtexec是TensorRT自带的命令行工具,适合快速验证ONNX模型是否可转换,无需编写复杂代码。
1. 基本命令(FP16量化示例)
# 格式:trtexec --onnx=优化后ONNX文件 --saveEngine=输出TRT引擎文件 --fp16
trtexec --onnx=resnet18_iot_simplified.onnx --saveEngine=resnet18_iot_trtexec.trt --fp16 --workspace=1024
2. INT8量化命令(需校准集)
trtexec --onnx=resnet18_iot_simplified.onnx --saveEngine=resnet18_iot_int8.trt --int8 --calib=./calibration_data --workspace=1024
3. 参数说明
--workspace:工作空间大小,单位MB(Jetson Nano推荐1024,即1GB)。--fp16/--int8:启用量化模式。--verbose:输出详细日志,用于排查转换错误。
五、步骤4:验证TensorRT引擎有效性
转换完成后,通过简单的推理代码验证引擎是否可用,同时查看推理耗时(适配IoT设备)。
验证代码
# 文件名:verify_trt_engine.py
import tensorrt as trt
import cv2
import numpy as np
import time
class TRTInferencer:
def __init__(self, engine_file):
self.logger = trt.Logger(trt.Logger.WARNING)
self.runtime = trt.Runtime(self.logger)
self.engine = self._load_trt_engine(engine_file)
self.context = self.engine.create_execution_context()
self._init_input_output()
def _load_trt_engine(self, engine_file):
"""加载TensorRT引擎"""
if not os.path.exists(engine_file):
raise FileNotFoundError(f"TensorRT引擎文件 {engine_file} 不存在")
with open(engine_file, "rb") as f:
serialized_engine = f.read()
return self.runtime.deserialize_cuda_engine(serialized_engine)
def _init_input_output(self):
"""初始化输入输出节点信息"""
self.input_name = "input_image"
self.output_name = "class_output"
self.input_idx = self.engine.get_binding_index(self.input_name)
self.output_idx = self.engine.get_binding_index(self.output_name)
self.input_shape = self.engine.get_binding_shape(self.input_idx)
self.output_shape = self.engine.get_binding_shape(self.output_idx)
def preprocess(self, img_path):
"""预处理(与训练/校准一致)"""
img = cv2.imread(img_path)
img = cv2.resize(img, (224, 224))
img = img.transpose((2, 0, 1)) / 255.0
img = np.expand_dims(img, axis=0).astype(np.float32)
return np.ascontiguousarray(img)
def infer(self, img_path):
"""执行推理并返回结果"""
input_data = self.preprocess(img_path)
# 分配内存
input_host = input_data.ravel()
output_host = np.empty(np.prod(self.output_shape), dtype=np.float32)
input_device = trt.cuda.DeviceMemory(len(input_host) * 4)
output_device = trt.cuda.DeviceMemory(len(output_host) * 4)
# 异步推理
stream = trt.cuda.Stream()
input_device.copy_from_host(input_host, stream)
# 计时推理
start_time = time.time()
self.context.execute_async_v2(
bindings=[int(input_device), int(output_device)],
stream_handle=stream.handle
)
stream.synchronize()
infer_time = (time.time() - start_time) * 1000 # 转换为毫秒
# 拷贝结果
output_device.copy_to_host(output_host, stream)
output_data = output_host.reshape(self.output_shape)
top1_class = np.argmax(output_data[0])
return {
"top1_class": top1_class,
"infer_time_ms": round(infer_time, 2)
}
if __name__ == "__main__":
inferencer = TRTInferencer("resnet18_iot.trt")
result = inferencer.infer("test.jpg")
print(f"Top-1类别:{result['top1_class']}")
print(f"单帧推理耗时:{result['infer_time_ms']}ms")
预期结果(Jetson Nano 4GB)
- INT8量化:推理耗时≤20ms,内存占用≤300MB。
- FP16量化:推理耗时≤30ms,内存占用≤400MB。
六、常见转换问题排查
-
报错:"ONNX node xxx is not supported by TensorRT"
- 原因:ONNX模型包含TensorRT不支持的算子,或opset版本过高。
- 解决方案:降低opset版本(如从14降至12)、用
onnx-simplifier重新优化、替换不支持的算子(如用普通Conv2d替换DepthwiseConv2d)。
-
报错:"libnvinfer.so.8: cannot open shared object file"
- 原因:TensorRT环境变量未配置生效。
- 解决方案:执行
export LD_LIBRARY_PATH=/usr/local/TensorRT/lib:$LD_LIBRARY_PATH,或永久配置环境变量。
-
INT8量化精度损失严重(>5%)
- 原因:校准集数量不足、分布不一致,或预处理步骤与训练不符。
- 解决方案:增加校准集至200张以上、确保校准集覆盖所有类别、严格对齐预处理步骤(归一化、缩放、转置顺序)。
-
引擎构建超时(IoT设备)
- 原因:工作空间设置过大、输入尺寸过大、批量校准值过高。
- 解决方案:减小工作空间(如512MB)、缩小输入尺寸(如224→192)、降低校准批量(如8→4)。