Compose: Android整合Yolo26e模型

0 阅读8分钟

还记得上一篇咱们聊的 Android整合Yolo模型 吗?当时用 TensorFlow Lite 在 Android 里整了个 YOLO 模型,但是留了个非常关键的问题没解决——

那就是:YOLO26 只能识别 80 种对象,那 80 种之外的东西咋办?总不能让手机变成"睁眼瞎"吧?

所以当时我提了两个后续整明白的事儿:

  1. 咋给模型做再训练?
  2. YOLO 家族的新成员 YOLOE(号称"实时感知一切"),这货能识别任何对象,比传统 YOLO 模型强多了,咋整合到 Android 里?

今天重点搞定第二个问题——如何把这个号称能识别万物的 YOLOE 模型塞进手机里,让APP也用明清灵水洗一洗眼睛!

要把 YOLOE 整到 Android 里,得先想明白俩事儿:

  1. 选哪个版本的模型?
  2. 用啥方法把模型整合到 Android 里?

1. YOLOE——万物识别小能手

YOLOE 这名字听着就牛逼,直译过来就是"实时感知一切"!这货专门为开放词汇表检测和分割而生,跟之前那些只能识别固定类别的 YOLO 模型完全不是一个量级——它能用文本、图像或者自带的词汇表当提示,实时识别任何你能想到的对象!

yoloe-seg-pf——懒人福音版

这里的"pf"是"prompt free"的意思,翻译过来就是"无提示词"。简单说就是,你啥都不用告诉它,直接扔张图片过去,它就能给你把里面的东西都认出来! 而且模型内置了4585种不同的对象类别,普通场景绝对够用。

from ultralytics import YOLOE

# 加载无提示词模型
model = YOLOE('yoloe-26l-seg-pf.pt')
# 直接预测图片
results = model.predict('bus.jpg')
# 显示结果
results[0].show()

YOLOE-seg——精准打击版

要是只想识别特定的东西(比如只看人和公交车),那就用 YOLOE-seg 模型,这货可以接受你指定的类别列表,精准定位你想看的东西,绝不浪费算力!

from ultralytics import YOLOE

# 加载标准模型
model = YOLOE('yoloe-26l-seg.pt')
# 设置只识别人和公交车
model.set_classes(['person', 'bus'])
# 开始预测
results = model.predict('bus.jpg')

2. 模型转换——踩坑记

上回整 YOLO 模型的时候,是把 PyTorch 模型转成 TensorFlow Lite 模型,当时可是花费了一坤日才学会的。

尝试转 TensorFlow Lite——失败

结果到 YOLOE 这儿,这招不灵了!直接转 TFLite 模型?门儿都没有! 转的时候直接报错,说什么"reshape 张量维度不匹配"。问了问千问才明白:

YOLOE 的 -seg 模型带了个实例分割头(mask head),里面用了动态 reshape,导出 ONNX 时没把维度固定死,导致 onnx2tf 转换时直接维度不匹配。

得,此路不通,试试换个法呗。

转向 ONNX——成功(但过程坎坷)

于是我想,不整这么复杂,ONNX 模型行不行?查了查官方文档,发现微软出的 ONNX Runtime 框架可以在 Android 上跑 ONNX 模型!

说干就干,转 ONNX 模型应该不难吧?就几行代码的事儿:

from ultralytics import YOLOE
model = YOLOE('yoloe-26l-seg-pf.pt')
model.export(format='onnx')

结果~~又报错了!千问给的解释是:

从你的模型名 YOLOE-26n-seg-pf.pt 可以看出:

  1. seg:表示支持实例分割;
  2. pf:表示 Prompt-Free 模式(即无需文本/视觉提示,自动检测所有物体)。 在 Prompt-Free 模式下,YOLOE 可能禁用了文本提示相关的分类头(text prompt head),导致 cls_head 或 bn_head 被设为 None。而当前导出流程中的 fuse() 函数 未正确处理这种“部分 head 缺失”的情况,直接对 None 做了 zip,从而崩溃。

解决方案就是,在导出ONNX模型的时候,禁用文本提示相关的分类头(text prompt head)。

def exportModel(modelname):
    model = YOLOE(modelname, task="detect")

    # 禁用整个模型的 fuse 行为
    original_fuse = getattr(model.model, 'fuse', None)
    if original_fuse is not None:
        model.model.fuse = lambda: model.model  # 返回自身,不 fuse
    
    # 同时禁用 head 的 fuse(双重保险)
    head = model.model.model[-1]
    if hasattr(head, 'fuse'):
        head.fuse = lambda *args, **kwargs: None
    
    model.export(format="onnx", half = True, dynamic=False, simplify=True)

对了,这里的 half=true 参数是为了减小模型大小,导出的是 FP16(float16)精度的模型,跟上次转 TFLite 模型时用的 float16 效果一样,省空间又不咋影响性能——简直是移动端的福音!

3. Android 中整合 ONNX Runtime——实战篇

第一步:添加依赖

先给项目加个 ONNX Runtime 依赖,基本上就像给手机装个插件一样简单:

libs.versions.toml

[versions]
onnxruntimeAndroid = "1.23.2"
onnxruntime-android = { module = "com.microsoft.onnxruntime:onnxruntime-android", version.ref = "onnxruntimeAndroid" }

build.gradle.kts

dependencies {
	implementation(libs.onnxruntime.android)
}

第二步:放模型文件

把转好的 ONNX 模型和分类文件丢到 assets 目录里:

对了,这里的 tag_list_chinese.txt 是从官方 GitHub 仓库下载的中文分类文件。

第三步:写识别代码

图像处理的部分上回已经聊过了,这次就不啰嗦了,直接上核心代码。

ONNX Runtime 用的是 OrtSession 来跑模型,咱整一个 OnnxYoloeModel 类:

OnnxYoloeModel.kt

package cn.mengfly.whereareyou.core.detect

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OnnxValue
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.Context
import android.graphics.Bitmap
import android.graphics.RectF
import cn.mengfly.whereareyou.core.loadClasses
import cn.mengfly.whereareyou.core.preProcessImage
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.nio.FloatBuffer

object OnnxYoloeModel : DetectModel {

    private const val INPUT = 640  // 输入图片大小
    private const val MODEL_PATH = "yoloe-26l-seg-pf.onnx"  // 模型路径
    private lateinit var session: OrtSession  // ONNX 会话
    private lateinit var env: OrtEnvironment  // ONNX 环境
    private lateinit var classes: List<String>  // 分类列表
    private const val CONF_THRESHOLD = 0.25f  // 置信度阈值

    /**
     * 模型是否初始化完成
     */
    override val isInit: Boolean
        get() = ::session.isInitialized

    override suspend fun init(context: Context) {
        // 加载分类文件
        classes = loadClasses(context, "tag_list_chinese.txt")
        // 加载 onnx 模型
        env = OrtEnvironment.getEnvironment();
        context.assets.open(MODEL_PATH).use {
            val readAllBytes = readAllBytes(it)
            session = env.createSession(readAllBytes, OrtSession.SessionOptions())
        }

    }

    override suspend fun detect(bitmap: Bitmap): List<DetectionResult> {

        // 预处理图像:缩放到 640x640,归一化到 [0, 1] 范围
        // 流程和上篇文章一样,就是把逻辑封装了一下
        val resized = bitmap.preProcessImage(INPUT)

        // 由于预处理后的图像数据为 HWC(height, width, channel)格式
        // 而ONNX模型的输入要求为(batch, channel, height, width)格式
        // 所以需要先将HWC转换为CHW(channel, height, width)格式
        val chwData = hwcToChw(resized.tensorBuffer.floatArray, INPUT, INPUT, 3)

        // 构建输入张量
        val inputTensor = OnnxTensor.createTensor(
            env,
            FloatBuffer.wrap(chwData),
            longArrayOf(1, 3, INPUT.toLong(), INPUT.toLong())
        )

        // 运行模型
        val inputs = mapOf<String, OnnxTensor>(
            session.inputNames.toList()[0] to inputTensor
        )
        val result = session.run(inputs, OrtSession.RunOptions())

        // 解析输出,过滤掉低置信度的结果
        return parseOutput(result[0])
            .applyNMS()  // 非极大值抑制,去除重叠的框
    }

    // HWC 转 CHW 格式的工具函数
    fun hwcToChw(hwcData: FloatArray, height: Int, width: Int, channels: Int): FloatArray {
        val chwData = FloatArray(hwcData.size)
        for (h in 0 until height) {
            for (w in 0 until width) {
                for (c in 0 until channels) {
                    val hwcIndex = (h * width + w) * channels + c
                    val chwIndex = c * height * width + h * width + w
                    chwData[chwIndex] = hwcData[hwcIndex]
                }
            }
        }
        return chwData
    }

    // 读取输入流的工具函数
    fun readAllBytes(inputStream: InputStream): ByteArray {
        val buffer = ByteArrayOutputStream()
        val data = ByteArray(1024)
        var bytesRead: Int
        while (inputStream.read(data).also { bytesRead = it } != -1) {
            buffer.write(data, 0, bytesRead)
        }
        return buffer.toByteArray()
    }

    /**
     * 解析模型输出
     * 注:这里只处理了检测框、分类和置信度,分割输出没处理
     */
    private fun parseOutput(output: OnnxValue): List<DetectionResult> {
        val tensor = output as OnnxTensor
        val detectResult = (tensor.value as Array<*>)[0] as Array<*>

        val result = mutableListOf<DetectionResult>()
        for (item in detectResult) {
            val detectRes = item as FloatArray

            // 提取检测信息
            val left = detectRes[0]
            val top = detectRes[1]
            val right = detectRes[2]
            val bottom = detectRes[3]
            val confidence = detectRes[4]
            val classType = detectRes[5].toInt()

            // 跳过低置信度的结果
            if (confidence < CONF_THRESHOLD) {
                continue
            }

            // 获取类别名称
            val classStr: String = if (classType >= classes.size) {
                "unknown"  // 未知类别
            } else {
                classes[classType]
            }

            // 添加到结果列表
            result.add(
                DetectionResult(
                    classStr,
                    confidence,
                    RectF(left, top, right, bottom)
                )
            )
        }
        return result
    }
}

整合识别结果

yoloe 虽然能识别的东西变多了,但是误识别的情况也跟着多了——有时候同一个东西,它能给你识别成好几种不同的物体,这就很尴尬了……

细心的伙伴应该注意到了,代码里调用了 applyNMS 方法,这玩意儿是干嘛的?

简单说,就是"非极大值抑制"——过滤掉那些重叠的、置信度低的检测框。但光有这还不够,我还做了个小优化:

如果两个不同类别的检测框几乎完全重合(IOU > 0.99),我就把置信度低的那个标记为"低可信度",显示的时候就能区分开了!

DetectionResult.kt


import android.graphics.RectF
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue

data class DetectionResult(
    val classType: String,  // 类别名称
    val confidence: Float,  // 置信度
    val boundingBox: RectF  // 检测框
) {
    var isSelected by mutableStateOf(false)  // 是否被选中显示
    var lowConfidence by mutableStateOf(false)  // 是否是低可信度结果
}


/**
 * 非极大值抑制(NMS):去除重叠的检测框与标记低可信度结果
 */
fun List<DetectionResult>.applyNMS(iouThreshold: Float = 0.45f): List<DetectionResult> {
    // 先按置信度从高到低排序
    val sorted = sortedByDescending { it.confidence }
    val keep = mutableListOf<DetectionResult>()

    for (curr in sorted) {
        var keepCurr = true
        for (kept in keep) {
            if (curr.classType == kept.classType) {
                // 相同类别,重叠度太高就干掉
                if (calculateIOU(curr.boundingBox, kept.boundingBox) > iouThreshold) {
                    keepCurr = false
                    break
                }
            } else {
                // 不同类别,但框几乎重合,标记为低可信度
                if (calculateIOU(curr.boundingBox, kept.boundingBox) > 0.99f) {
                    curr.lowConfidence = true
                    break
                }
            }
        }
        if (keepCurr) keep.add(curr)
    }
    return keep
}

/**
 * 计算交并比(IOU):判断两个框重叠程度
 */
private fun calculateIOU(box1: RectF, box2: RectF): Float {
    // 计算重叠区域的坐标
    val intersectLeft = maxOf(box1.left, box2.left)
    val intersectTop = maxOf(box1.top, box2.top)
    val intersectRight = minOf(box1.right, box2.right)
    val intersectBottom = minOf(box1.bottom, box2.bottom)
    
    // 计算重叠面积
    val intersectArea = 
        maxOf(0f, intersectRight - intersectLeft) * maxOf(0f, intersectBottom - intersectTop)
    
    // 计算两个框的总面积
    val box1Area = (box1.right - box1.left) * (box1.bottom - box1.top)
    val box2Area = (box2.right - box2.left) * (box2.bottom - box2.top)
    
    // 交并比 = 重叠面积 / (总面积 - 重叠面积)
    return intersectArea / (box1Area + box2Area - intersectArea)
}

4. 识别结果

至于具体怎么调用模型,上一篇已经聊得很详细了,这里就不啰嗦了。不过这次整合的时候,做了两个角度的优化:

  1. 智能显示:yoloe 识别的东西太多了,要是全显示出来页面得乱成一锅粥。所以改成了手动选择——想看哪个点哪个,清爽又方便!

  2. 模型切换:我把之前的 TensorFlow 模型也保留了,抽象了一个 DetectModel 接口,想切哪个模型就切哪个!

话不多说,直接看效果。

对了,模型和源码咱们已经打包上传到网盘了,具体链接就在我的这篇公众号文章里,需要的小伙伴自己去拿哈!

mp.weixin.qq.com/s/EfB4Gd3oS…


总结一下:这次成功把 YOLOE 这个号称"识别万物"的模型整到了 Android 里,解决了之前 YOLO 模型只能识别 80 种对象的问题。虽然过程中踩了不少坑,但最终效果还是挺不错的!

各位小伙伴要是有什么问题,欢迎在评论区留言,一起交流学习!