为什么学这个
在边缘计算设备(如 Jetson Orin NX)上部署深度学习模型时,算力和显存往往是最大的瓶颈。虽然 PyTorch 在训练阶段非常灵活,但在推理阶段直接跑原生模型效率太低。为了榨干硬件性能,将模型转换为 TensorRT 引擎是必经之路。
最近我系统性地跑通了 PyTorch -> ONNX -> TensorRT (FP16 & INT8) 的整套部署流程。这篇文章是我对这次实践的复盘,提炼了一套可以直接复用的通用代码模板,希望能帮大家少走弯路。
核心内容与普适性步骤
整个加速流程主要分为四个阶段:导出 ONNX、FP16 编译、INT8 校准与编译、以及最终的推理验证。以下代码已去除具体的业务逻辑,只要替换你的模型和预处理代码即可直接运行。
第一步:PyTorch 导出 ONNX
安装 ONNX 环境:pip install onnx==1.15.0。
导出的核心是构造一个 Dummy Input,并冻结模型权重。在这里,我习惯开启常量折叠(do_constant_folding=True)来精简计算图。
Python
import torch
import onnx
def export_to_onnx(pytorch_model_path, onnx_save_path, input_shape=(1, 3, 224, 224)):
# 1. 初始化你的模型 (TODO: 替换为你自己的模型类)
model = MyModel()
# 2. 加载权重
state_dict = torch.load(pytorch_model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# 3. 切换评估模式并冻结梯度
model.eval()
for p in model.parameters():
p.requires_grad = False
# 4. 构造 Dummy Input
dummy_input = torch.randn(*input_shape)
# 5. 导出 ONNX
torch.onnx.export(
model,
dummy_input,
onnx_save_path,
export_params=True,
opset_version=17, # 建议使用 17 或以上
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes=None # 如果需要动态 Batch 可以在这里设置
)
# 6. 校验模型合法性
onnx_model = onnx.load(onnx_save_path)
onnx.checker.check_model(onnx_model)
print(f"✅ ONNX 导出成功: {onnx_save_path}")
# 运行测试
# export_to_onnx("model.pth", "model.onnx")
第二步:使用 trtexec 编译 FP16 引擎
对于 FP16 模型,最快的方式是直接使用官方的 trtexec 工具。可以通过 Shell 脚本实现批量转换。为了保持控制台整洁,我通常会将海量的底层构建日志重定向隐藏掉。
Bash
#!/bin/bash
# 确保 trtexec 在你的环境变量中,或者指定绝对路径 (例如 /usr/src/tensorrt/bin/trtexec)
TRTEXEC_PATH="trtexec"
ONNX_FILE="./models/model.onnx"
ENGINE_FILE="./models/model_fp16.engine"
# 执行转换命令
$TRTEXEC_PATH \
--onnx="$ONNX_FILE" \
--fp16 \
--memPoolSize=workspace:2048MiB \
--saveEngine="$ENGINE_FILE" \
> /dev/null 2>&1 # 隐藏冗长的日志
if [ $? -eq 0 ]; then
echo "✅ 成功生成: $ENGINE_FILE"
else
echo "❌ 编译失败"
fi
你还可以使用以下命令直接测试生成的 engine 性能:
trtexec --loadEngine=model_fp16.engine --iterations=10 --warmUp=100 --duration=3
第三步:Python API 编译 INT8 引擎 (含校准器)
INT8 量化需要校准数据来统计张量的激活值分布。我们需要实现 trt.IInt8EntropyCalibrator2 接口。
核心坑点:校准器中的图像预处理必须与 PyTorch 训练时的预处理完全一致,否则量化后的精度会雪崩!
Python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import os
import glob
from PIL import Image
# 1. 定义校准器
class CustomCalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, calib_dir, input_shape, cache_file="calib.cache"):
super().__init__()
self.cache_file = cache_file
self.image_files = glob.glob(os.path.join(calib_dir, "*.jpg"))[:50] # 选50张代表性图片
self.current_index = 0
self.input_shape = input_shape
self.input_size = int(np.prod(input_shape) * 4) # float32 bytes
self.device_input = cuda.mem_alloc(self.input_size)
def get_batch_size(self):
return self.input_shape[0]
def get_batch(self, names):
if self.current_index >= len(self.image_files):
return None
img_path = self.image_files[self.current_index]
self.current_index += 1
# TODO: 替换为你的严格预处理逻辑 (Resize, 归一化等)
img = Image.open(img_path).convert('RGB').resize((self.input_shape[2], self.input_shape[3]))
img_np = np.array(img).astype(np.float32) / 255.0
img_np = img_np.transpose(2, 0, 1)[np.newaxis, ...] # HWC to NCHW
# 拷贝到显存
cuda.memcpy_htod(self.device_input, np.ascontiguousarray(img_np))
return [int(self.device_input)]
def read_calibration_cache(self):
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
return f.read()
return None
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)
def __del__(self):
if hasattr(self, 'device_input'):
self.device_input.free()
# 2. 构建 INT8 Engine
def build_int8_engine(onnx_path, engine_path, calib_dir):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(onnx_path, 'rb') as f:
parser.parse(f.read())
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)
# TensorRT 10.x 使用 set_memory_pool_limit 设置 workspace
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 * (1 << 30)) # 2GB
calibrator = CustomCalibrator(calib_dir, input_shape=(1, 3, 224, 224))
config.int8_calibrator = calibrator
serialized_engine = builder.build_serialized_network(network, config)
with open(engine_path, 'wb') as f:
f.write(serialized_engine)
print(f"🎉 INT8 engine 构建成功: {engine_path}")
# 运行测试
# build_int8_engine("model.onnx", "model_int8.engine", "./calib_images/")
第四步:基于 PyCUDA 的通用推理代码
有了 .engine 文件后,我们就可以脱离 PyTorch,使用 pycuda 直接操作显存进行极速推理了。为了避免内存泄漏,我封装了一个通用类。
Python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
class TRTInfer:
def __init__(self, engine_path):
with open(engine_path, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
self.engine = runtime.deserialize_cuda_engine(f.read())
self.context = self.engine.create_execution_context()
self.stream = cuda.Stream()
# 动态获取输入输出 shape (假设单输入单输出)
self.input_name = self.engine.get_tensor_name(0)
self.output_name = self.engine.get_tensor_name(1)
self.input_shape = self.engine.get_tensor_shape(self.input_name)
self.output_shape = self.engine.get_tensor_shape(self.output_name)
# 分配显存
self.d_input = cuda.mem_alloc(trt.volume(self.input_shape) * 4)
self.d_output = cuda.mem_alloc(trt.volume(self.output_shape) * 4)
# TensorRT 10.x 绑定内存地址
self.context.set_tensor_address(self.input_name, self.d_input)
self.context.set_tensor_address(self.output_name, self.d_output)
def infer(self, input_data):
"""异步推理核心代码"""
# HtoD: 主机到设备
cuda.memcpy_htod_async(self.d_input, np.ascontiguousarray(input_data).astype(np.float32), self.stream)
# 执行
self.context.execute_async_v3(self.stream.handle)
# DtoH: 设备到主机
output = np.empty(self.output_shape, dtype=np.float32)
cuda.memcpy_dtoh_async(output, self.d_output, self.stream)
self.stream.synchronize()
return output
def free(self):
"""强制释放显存,防止 OOM"""
self.d_input.free()
self.d_output.free()
del self.context
del self.engine
# 运行测试
# trt_model = TRTInfer("model_fp16.engine")
# dummy_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# out = trt_model.infer(dummy_data)
# trt_model.free()
遇到的问题与解决方法
-
批量处理时显存溢出 (OOM) 或卡死
- 现象:在循环测试多个 Engine 模型或者批量转换时,跑到一半直接进程崩溃。
- 解决:Python 的垃圾回收机制对 GPU 显存的释放并不及时。在每次推理或转换结束后,必须显式调用
free()方法释放 PyCUDA 分配的d_input和d_output,使用del删除 context 和 engine,并且加上gc.collect()强制回收。
-
INT8 精度下降严重
- 现象:FP16 推理结果完美,一转 INT8 就变成满屏噪点。
- 解决:检查
Calibrator中的get_batch方法。图像传入校准器前,必须经过与 PyTorch 训练集完全一模一样的处理逻辑(包括通道顺序 RGB/BGR,通道转换如 YCbCr,以及数值归一化/255.0等)。
-
TensorRT 10.x 接口变更问题
- 现象:旧版本的
set_workspace_size会报废弃警告或直接报错。 - 解决:TensorRT 10.x 中需要使用
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, size)来设置工作空间大小。同时,绑定内存地址需要使用set_tensor_address取代旧的bindings数组。
- 现象:旧版本的
收获与总结
从 PyTorch 无脑 forward() 切换到自己管理显存(cuda.memcpy_htod_async)、自己控制推理流(execute_async_v3),是一次对底层硬件运行逻辑的深度洗礼。
TensorRT 带来的 FPS 提升是立竿见影的。对于边缘设备开发者来说,掌握这套标准的 ONNX -> TRT 工作流,是让算法真正落地产生业务价值的关键技能。建议大家在部署初期养成良好的习惯:先单独跑通单张图的精度对齐,再去写批处理逻辑。