如何在Android上使用ONNX模型?

557 阅读10分钟

随着AI的发展,在移动端上使用本地模型已经不是新鲜事儿,最近在Android端上尝试使用onnx模型实时识别车牌。图像的获取采用Android CemeraX + OpengGL架构,这一部分不做详细介绍,已经是很成熟的技术架构。

前期准备

Android端使用onnx模型,必须引入官方提供的包:

implementation("com.microsoft.onnxruntime:onnxruntime-android:1.22.0")

将onnx模型文件放在项目模块的res/raw目录下,直接在代码里使用下列代码进行模型文件的访问,会使用到OrtEnvironment类,OrtSession类:

        this.env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
 
        // Load model files from the res/raw directory
        byte[] detectionModel = readModelFromRaw(context, R.raw.yolov7plate);
        byte[] recognitionModel = readModelFromRaw(context, R.raw.plate_recognition_color);
 
        // Create inference sessions
        this.detectionSession = env.createSession(detectionModel, opts);
        this.recognitionSession = env.createSession(recognitionModel, opts);

模型可以看作传统代码里的一个方法,需要一个输出,经过模型内部计算,得到一个输出:

        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer, inputShape);
        // Run inference
        // 向模型中输入的时候,需要先了解模型的输入由哪些数据组成
        OrtSession.Result result = detectionSession.run(Collections.singletonMap("images", inputTensor));
       

在使用模型的时候,需要先了解模型的输入和输出是由什么数据组成的,就像我们日常调用一个方法一样。

60b223ce38b949e3aa0230cb704c8c08.png

模型的输入

通过查看模型文件的结构,发现在这个检测模型yolov7plate.onnx中,它的输入要求标识符维“images”,数据类型是float32[1,3,640,640]:

数据格式含义

  • 1 张 RGB 彩色图像
  • 大小为 640×640 像素
  • 数据类型是 float(通常是 float32)

就是说张量的形状(tensor shape)是一张图像通道为3(RGB不是ARGB),尺寸是640x640的图像。其中“images”这个key是ONNX模型内部定义的输入标识符。你必须使用与模型文件定义完全一致的名称,否则ONNX Runtime会因为找不到对应的输入节点而报错。

同时,当ONNX Runtime准备执行模型的第一层计算时,它会进行一个形状检查,发现你提供的数据形状 [..., 640, 480] 与模型期望的 [..., 640, 640] 在最后一个维度上不匹配,于是就会立刻抛出一个 OrtException 或类似的错误,错误信息通常会包含 "Shape mismatch" 或 "Dimension mismatch" 等字样。

为什么模型要求这么严格?

神经网络的结构是刚性的。模型内部的权重、卷积核、全连接层等所有计算单元的尺寸都是在训练时根据640x640这个输入尺寸被固定下来的。

神经网络的结构是刚性的。模型内部的权重、卷积核、全连接层等所有计算单元的尺寸都是在训练时根据640x640这个输入尺寸被固定下来的。

  • 卷积层:虽然卷积操作本身对输入尺寸有一定的灵活性,但后续的层不一定。
  • 全连接层 (Fully Connected Layer) :在很多模型的末端,特征图会被“展平”(Flatten)成一个一维向量,然后送入全连接层。这个展平后的一维向量的长度是严格依赖于输入特征图的尺寸的。如果输入尺寸不对,向量长度就会错,导致后续所有矩阵乘法全部失败。

输入前处理

但是,通常我们拍照的图片尺寸并不一定是640x640大小,所以需要我们对原始图片进行处理,这个处理的过程被称之为:LetterBoxing( Letterbox是一种图像调整技术,通过在图像的边缘填充一定数量的像素,使图像的尺寸满足特定的要求。)

数据格式含义

第一个维度 1:批处理大小 (Batch Size)
  • 含义:这代表你一次性输入给模型进行推理的图片数量。
  • 在这里:1 表示模型每次只处理一张图片。这是在移动端进行实时推理时最常见的情况。如果你一次性输入4张图片进行处理,这个维度就会是 4。
第二个维度 25200:预测框的总数量 (Total Number of Predictions)
  • 含义:这可能是最让人困惑的部分。模型并不会直接告诉你“图里有3个物体”。相反,它会在图片上预设的成千上万个“锚点”(Anchors)或“网格点”(Grid cells)上都进行一次预测。25200 就是所有这些预测的总和。

  • 来源:这个数字通常是由模型内部不同尺寸的特征图(feature map)上的预测数量相加得来的。例如,一个典型的YOLO模型可能会在三个不同尺度的网格上进行预测:

    • 一个 80x80 的大特征图(用于检测小物体)
    • 一个 40x40 的中特征图(用于检测中等物体)
    • 一个 20x20 的小特征图(用于检测大物体)
    • 如果每个网格单元都关联3个不同形状的锚框(anchor boxes),那么总预测数就是:
      (80 * 80 * 3) + (40 * 40 * 3) + (20 * 20 * 3) = 19200 + 4800 + 1200 = 25200
  • 关键点:这 25200 个预测中,绝大部分都是无用的(背景),你需要通过后续处理来筛选出真正有效的预测。

第三个维度 19:每个预测框的详细信息 (Prediction Vector)
  • 含义:这表示对于 25200 个预测中的每一个,模型都输出了一个包含 19 个浮点数的向量,用来描述这个预测。

  • 这19个值的具体构成
    [center_x, center_y, width, height, objectness_confidence, class_score_1, class_score_2, ..., class_score_14]

    我们来分解一下:

    • 前4个值 (center_x, center_y, width, height) : 这是预测的“边界框(Bounding Box) ”信息。

      • center_x, center_y: 边界框的中心点 x, y 坐标。
      • width, height: 边界框的宽度和高度。
      • 注意:这些值通常是归一化的(即值在0到1之间),你需要将它们乘以原始图片(输入时的图片)的宽度和高度才能得到实际的像素坐标。
    • 第5个值 (objectness_confidence) : 这是置信度分数

      • 它表示模型认为这个边界框内存在一个物体的可能性有多大。分数越接近1,表示模型越确定这里面有东西;越接近0,表示模型认为这里是背景。这是你进行初步筛选最关键的依据。
    • 剩下的14个值 (class_score_1 到 class_score_14) : 这是类别分数

      • 这里的 14 是通过计算 19 - 4 (边界框) - 1 (置信度) = 14 得来的。
      • 这意味你的模型被训练来识别14个不同的物体类别
      • 这14个值分别代表这个预测框是“类别1”的概率、“类别2”的概率……以此类推。你需要在这14个分数中找到最高分,该最高分对应的类别就是模型对这个物体的预测类别。

输出后处理 (Post-processing)

       理解了这个输出结构后,你的代码需要做的后处理工作流程如下:

  1. 遍历所有预测:遍历这 25200 个预测行。
  2. 过滤置信度:对于每一行,检查第5个值(objectness_confidence)。如果这个值低于一个你设定的阈值(例如 0.5),就直接忽略这一行,因为它很可能是背景。
  3. 获取类别和类别分数:对于通过了置信度过滤的行,查看后面的14个类别分数。找到分数最高的那个,这个分数就是类别置信度,它的索引位置就代表了预测的类别。
  4. 解码边界框:获取前4个值,并将它们乘以图片的实际宽高,得到在图片上的像素坐标和尺寸。
  5. 应用非极大值抑制 (NMS) :经过以上步骤后,你可能会对同一个物体得到多个重叠的边界框。NMS是一种标准算法,它会帮你消除这些重复的框,只保留置信度最高的那个。
/**
* Non-maximum suppression (NMS) algorithm.
* @param boxes
* @param iouThreshold
* @return
*/
private static List<BoundingBox> applyNms(List<BoundingBox> boxes, float iouThreshold) {
        boxes.sort(Comparator.comparing(b -> -b.confidence));
        List<BoundingBox> selectedBoxes = new ArrayList<>();

        while (!boxes.isEmpty()) {
            BoundingBox bestBox = boxes.get(0);
            selectedBoxes.add(bestBox);
            boxes.remove(0);

            boxes.removeIf(box -> calculateIoU(bestBox, box) > iouThreshold);
        }
        return selectedBoxes;
    }

总结

        这么看来,在移动端使用本地模型,其实可以把它当做一个方法,只不过这个方法内部很庞大,前期要对模型的输入和输出数据格式进行了解,前期输入需要将数据处理成模型能够理解的格式,后期需要将输出的数据处理成传统程序需要的数据格式。

扩展:

何为张量?

你可以理解为是一个多维数组,算是深度学习框架里最基本的数据结构。

0维张量,只有一个数,可以称之为标量

1维张量,是一维数组[1,2,3],可以理解为向量

以此类推,2维张量,是二维数组[[1,2,3],[4,5,6]],可以理解为矩阵, 高维张量,例如三维可以代表立方体的数据,四维常见于图像批次数据。

像上文提到的模型的输入:

  • 图片

    • 一张彩色图像 = [3, H, W] 张量 (3 通道,H 高度,W 宽度)。
    • 一批图像 = [N, 3, H, W] 张量(N 是 batch size)。

未命名项目.jpeg

这张 3D 示意图展示了 YOLOv7 输入张量 [1, 3, H, W] 的结构:

  • Z 轴方向:三个通道(R、G、B),相当于三层矩阵。
  • X、Y 平面:表示图像的宽度 (W) 和高度 (H)。
  • Batch=1:这里只有一张图像,所以最外层 batch 维度是 1。

把三层(RGB)叠在一起,就能构成一张彩色图片的数据表示,正是模型需要的输入。

边界框格式

两种主要的边界框格式

1. 角点格式 (Corner Format): (x1, y1, x2, y2)
  • x1, y1: 矩形 左上角 的坐标。
  • x2, y2: 矩形 右下角 的坐标。
  • 这是您代码目前假设的格式,也是安卓 RectF 类使用的格式。
2. 中心点格式 (Center Format): (center_x, center_y, width, height) 或 (cx, cy, w, h)
  • center_x, center_y: 矩形 中心点 的坐标。
  • width, height: 矩形的 宽度 和 高度
  • 这是YOLO系列模型的标准输出格式。因为YOLO的原理就是在网格单元格(grid cell)的中心附近去预测物体中心,所以它的原生输出就是中心点格式。

如何确定模型的输出格式?

  1. 查阅模型文档(最佳途径) :如果您知道这个ONNX模型是从哪个项目(如某个GitHub仓库)来的,去查阅它的文档或Python源代码。特别是后处理部分(post-processing),肯定会有代码将模型的原始输出解码成边界框,一看便知。
  2. 根据模型类型推断(非常可靠) :正如我们之前讨论的,[1, 25200, 19] 这样的输出形状是YOLOv5/v7等模型的典型特征。这些模型的原生输出都是中心点格式 (cx, cy, w, h)。
  3. 逻辑推断和实验:鉴于您目前遇到的所有奇怪的坐标问题(负数、0、天文数字),最合理的解释就是您错误地将中心点坐标当成了角点坐标来处理。

特别感谢:

感谢开源项目: github.com/pcb9382/Pla…,使用的onnx模型来自于这里。

同时非常感谢一起调研的同事,团队协作的power。

对AI时代的生活美学、技术和产品探讨感兴趣的可以关注公众号

扫码_搜索联合传播样式-白色版.png