随着AI的发展,在移动端上使用本地模型已经不是新鲜事儿,最近在Android端上尝试使用onnx模型实时识别车牌。图像的获取采用Android CemeraX + OpengGL架构,这一部分不做详细介绍,已经是很成熟的技术架构。
前期准备
Android端使用onnx模型,必须引入官方提供的包:
implementation("com.microsoft.onnxruntime:onnxruntime-android:1.22.0")
将onnx模型文件放在项目模块的res/raw目录下,直接在代码里使用下列代码进行模型文件的访问,会使用到OrtEnvironment类,OrtSession类:
this.env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Load model files from the res/raw directory
byte[] detectionModel = readModelFromRaw(context, R.raw.yolov7plate);
byte[] recognitionModel = readModelFromRaw(context, R.raw.plate_recognition_color);
// Create inference sessions
this.detectionSession = env.createSession(detectionModel, opts);
this.recognitionSession = env.createSession(recognitionModel, opts);
模型可以看作传统代码里的一个方法,需要一个输出,经过模型内部计算,得到一个输出:
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer, inputShape);
// Run inference
// 向模型中输入的时候,需要先了解模型的输入由哪些数据组成
OrtSession.Result result = detectionSession.run(Collections.singletonMap("images", inputTensor));
在使用模型的时候,需要先了解模型的输入和输出是由什么数据组成的,就像我们日常调用一个方法一样。
模型的输入
通过查看模型文件的结构,发现在这个检测模型yolov7plate.onnx中,它的输入要求标识符维“images”,数据类型是float32[1,3,640,640]:
数据格式含义
- 1 张 RGB 彩色图像
- 大小为 640×640 像素
- 数据类型是 float(通常是 float32)
就是说张量的形状(tensor shape)是一张图像通道为3(RGB不是ARGB),尺寸是640x640的图像。其中“images”这个key是ONNX模型内部定义的输入标识符。你必须使用与模型文件定义完全一致的名称,否则ONNX Runtime会因为找不到对应的输入节点而报错。
同时,当ONNX Runtime准备执行模型的第一层计算时,它会进行一个形状检查,发现你提供的数据形状 [..., 640, 480] 与模型期望的 [..., 640, 640] 在最后一个维度上不匹配,于是就会立刻抛出一个 OrtException 或类似的错误,错误信息通常会包含 "Shape mismatch" 或 "Dimension mismatch" 等字样。
为什么模型要求这么严格?
神经网络的结构是刚性的。模型内部的权重、卷积核、全连接层等所有计算单元的尺寸都是在训练时根据640x640这个输入尺寸被固定下来的。
神经网络的结构是刚性的。模型内部的权重、卷积核、全连接层等所有计算单元的尺寸都是在训练时根据640x640这个输入尺寸被固定下来的。
- 卷积层:虽然卷积操作本身对输入尺寸有一定的灵活性,但后续的层不一定。
- 全连接层 (Fully Connected Layer) :在很多模型的末端,特征图会被“展平”(Flatten)成一个一维向量,然后送入全连接层。这个展平后的一维向量的长度是严格依赖于输入特征图的尺寸的。如果输入尺寸不对,向量长度就会错,导致后续所有矩阵乘法全部失败。
输入前处理
但是,通常我们拍照的图片尺寸并不一定是640x640大小,所以需要我们对原始图片进行处理,这个处理的过程被称之为:LetterBoxing( Letterbox是一种图像调整技术,通过在图像的边缘填充一定数量的像素,使图像的尺寸满足特定的要求。)
数据格式含义
第一个维度 1:批处理大小 (Batch Size)
- 含义:这代表你一次性输入给模型进行推理的图片数量。
- 在这里:1 表示模型每次只处理一张图片。这是在移动端进行实时推理时最常见的情况。如果你一次性输入4张图片进行处理,这个维度就会是 4。
第二个维度 25200:预测框的总数量 (Total Number of Predictions)
-
含义:这可能是最让人困惑的部分。模型并不会直接告诉你“图里有3个物体”。相反,它会在图片上预设的成千上万个“锚点”(Anchors)或“网格点”(Grid cells)上都进行一次预测。25200 就是所有这些预测的总和。
-
来源:这个数字通常是由模型内部不同尺寸的特征图(feature map)上的预测数量相加得来的。例如,一个典型的YOLO模型可能会在三个不同尺度的网格上进行预测:
- 一个 80x80 的大特征图(用于检测小物体)
- 一个 40x40 的中特征图(用于检测中等物体)
- 一个 20x20 的小特征图(用于检测大物体)
- 如果每个网格单元都关联3个不同形状的锚框(anchor boxes),那么总预测数就是:
(80 * 80 * 3) + (40 * 40 * 3) + (20 * 20 * 3) = 19200 + 4800 + 1200 = 25200
-
关键点:这 25200 个预测中,绝大部分都是无用的(背景),你需要通过后续处理来筛选出真正有效的预测。
第三个维度 19:每个预测框的详细信息 (Prediction Vector)
-
含义:这表示对于 25200 个预测中的每一个,模型都输出了一个包含 19 个浮点数的向量,用来描述这个预测。
-
这19个值的具体构成:
[center_x, center_y, width, height, objectness_confidence, class_score_1, class_score_2, ..., class_score_14]我们来分解一下:
-
前4个值 (center_x, center_y, width, height) : 这是预测的“边界框(Bounding Box) ”信息。
- center_x, center_y: 边界框的中心点 x, y 坐标。
- width, height: 边界框的宽度和高度。
- 注意:这些值通常是归一化的(即值在0到1之间),你需要将它们乘以原始图片(输入时的图片)的宽度和高度才能得到实际的像素坐标。
-
第5个值 (objectness_confidence) : 这是置信度分数。
- 它表示模型认为这个边界框内存在一个物体的可能性有多大。分数越接近1,表示模型越确定这里面有东西;越接近0,表示模型认为这里是背景。这是你进行初步筛选最关键的依据。
-
剩下的14个值 (class_score_1 到 class_score_14) : 这是类别分数。
- 这里的 14 是通过计算 19 - 4 (边界框) - 1 (置信度) = 14 得来的。
- 这意味你的模型被训练来识别14个不同的物体类别。
- 这14个值分别代表这个预测框是“类别1”的概率、“类别2”的概率……以此类推。你需要在这14个分数中找到最高分,该最高分对应的类别就是模型对这个物体的预测类别。
-
输出后处理 (Post-processing)
理解了这个输出结构后,你的代码需要做的后处理工作流程如下:
- 遍历所有预测:遍历这 25200 个预测行。
- 过滤置信度:对于每一行,检查第5个值(objectness_confidence)。如果这个值低于一个你设定的阈值(例如 0.5),就直接忽略这一行,因为它很可能是背景。
- 获取类别和类别分数:对于通过了置信度过滤的行,查看后面的14个类别分数。找到分数最高的那个,这个分数就是类别置信度,它的索引位置就代表了预测的类别。
- 解码边界框:获取前4个值,并将它们乘以图片的实际宽高,得到在图片上的像素坐标和尺寸。
- 应用非极大值抑制 (NMS) :经过以上步骤后,你可能会对同一个物体得到多个重叠的边界框。NMS是一种标准算法,它会帮你消除这些重复的框,只保留置信度最高的那个。
/**
* Non-maximum suppression (NMS) algorithm.
* @param boxes
* @param iouThreshold
* @return
*/
private static List<BoundingBox> applyNms(List<BoundingBox> boxes, float iouThreshold) {
boxes.sort(Comparator.comparing(b -> -b.confidence));
List<BoundingBox> selectedBoxes = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.get(0);
selectedBoxes.add(bestBox);
boxes.remove(0);
boxes.removeIf(box -> calculateIoU(bestBox, box) > iouThreshold);
}
return selectedBoxes;
}
总结
这么看来,在移动端使用本地模型,其实可以把它当做一个方法,只不过这个方法内部很庞大,前期要对模型的输入和输出数据格式进行了解,前期输入需要将数据处理成模型能够理解的格式,后期需要将输出的数据处理成传统程序需要的数据格式。
扩展:
何为张量?
你可以理解为是一个多维数组,算是深度学习框架里最基本的数据结构。
0维张量,只有一个数,可以称之为标量;
1维张量,是一维数组[1,2,3],可以理解为向量;
以此类推,2维张量,是二维数组[[1,2,3],[4,5,6]],可以理解为矩阵, 高维张量,例如三维可以代表立方体的数据,四维常见于图像批次数据。
像上文提到的模型的输入:
-
图片
- 一张彩色图像 = [3, H, W] 张量 (3 通道,H 高度,W 宽度)。
- 一批图像 = [N, 3, H, W] 张量(N 是 batch size)。
这张 3D 示意图展示了 YOLOv7 输入张量 [1, 3, H, W] 的结构:
- Z 轴方向:三个通道(R、G、B),相当于三层矩阵。
- X、Y 平面:表示图像的宽度 (W) 和高度 (H)。
- Batch=1:这里只有一张图像,所以最外层 batch 维度是 1。
把三层(RGB)叠在一起,就能构成一张彩色图片的数据表示,正是模型需要的输入。
边界框格式
两种主要的边界框格式
1. 角点格式 (Corner Format): (x1, y1, x2, y2)
- x1, y1: 矩形 左上角 的坐标。
- x2, y2: 矩形 右下角 的坐标。
- 这是您代码目前假设的格式,也是安卓 RectF 类使用的格式。
2. 中心点格式 (Center Format): (center_x, center_y, width, height) 或 (cx, cy, w, h)
- center_x, center_y: 矩形 中心点 的坐标。
- width, height: 矩形的 宽度 和 高度。
- 这是YOLO系列模型的标准输出格式。因为YOLO的原理就是在网格单元格(grid cell)的中心附近去预测物体中心,所以它的原生输出就是中心点格式。
如何确定模型的输出格式?
- 查阅模型文档(最佳途径) :如果您知道这个ONNX模型是从哪个项目(如某个GitHub仓库)来的,去查阅它的文档或Python源代码。特别是后处理部分(post-processing),肯定会有代码将模型的原始输出解码成边界框,一看便知。
- 根据模型类型推断(非常可靠) :正如我们之前讨论的,[1, 25200, 19] 这样的输出形状是YOLOv5/v7等模型的典型特征。这些模型的原生输出都是中心点格式 (cx, cy, w, h)。
- 逻辑推断和实验:鉴于您目前遇到的所有奇怪的坐标问题(负数、0、天文数字),最合理的解释就是您错误地将中心点坐标当成了角点坐标来处理。
特别感谢:
感谢开源项目: github.com/pcb9382/Pla…,使用的onnx模型来自于这里。
同时非常感谢一起调研的同事,团队协作的power。
对AI时代的生活美学、技术和产品探讨感兴趣的可以关注公众号