还记得上一篇咱们聊的 Android整合Yolo模型 吗?当时用 TensorFlow Lite 在 Android 里整了个 YOLO 模型,但是留了个非常关键的问题没解决——
那就是:YOLO26 只能识别 80 种对象,那 80 种之外的东西咋办?总不能让手机变成"睁眼瞎"吧?
所以当时我提了两个后续整明白的事儿:
- 咋给模型做再训练?
- YOLO 家族的新成员 YOLOE(号称"实时感知一切"),这货能识别任何对象,比传统 YOLO 模型强多了,咋整合到 Android 里?
今天重点搞定第二个问题——如何把这个号称能识别万物的 YOLOE 模型塞进手机里,让APP也用明清灵水洗一洗眼睛!
要把 YOLOE 整到 Android 里,得先想明白俩事儿:
- 选哪个版本的模型?
- 用啥方法把模型整合到 Android 里?
1. YOLOE——万物识别小能手
YOLOE 这名字听着就牛逼,直译过来就是"实时感知一切"!这货专门为开放词汇表检测和分割而生,跟之前那些只能识别固定类别的 YOLO 模型完全不是一个量级——它能用文本、图像或者自带的词汇表当提示,实时识别任何你能想到的对象!
yoloe-seg-pf——懒人福音版
这里的"pf"是"prompt free"的意思,翻译过来就是"无提示词"。简单说就是,你啥都不用告诉它,直接扔张图片过去,它就能给你把里面的东西都认出来! 而且模型内置了4585种不同的对象类别,普通场景绝对够用。
from ultralytics import YOLOE
# 加载无提示词模型
model = YOLOE('yoloe-26l-seg-pf.pt')
# 直接预测图片
results = model.predict('bus.jpg')
# 显示结果
results[0].show()
YOLOE-seg——精准打击版
要是只想识别特定的东西(比如只看人和公交车),那就用 YOLOE-seg 模型,这货可以接受你指定的类别列表,精准定位你想看的东西,绝不浪费算力!
from ultralytics import YOLOE
# 加载标准模型
model = YOLOE('yoloe-26l-seg.pt')
# 设置只识别人和公交车
model.set_classes(['person', 'bus'])
# 开始预测
results = model.predict('bus.jpg')
2. 模型转换——踩坑记
上回整 YOLO 模型的时候,是把 PyTorch 模型转成 TensorFlow Lite 模型,当时可是花费了一坤日才学会的。
尝试转 TensorFlow Lite——失败
结果到 YOLOE 这儿,这招不灵了!直接转 TFLite 模型?门儿都没有! 转的时候直接报错,说什么"reshape 张量维度不匹配"。问了问千问才明白:
YOLOE 的 -seg 模型带了个实例分割头(mask head),里面用了动态 reshape,导出 ONNX 时没把维度固定死,导致 onnx2tf 转换时直接维度不匹配。
得,此路不通,试试换个法呗。
转向 ONNX——成功(但过程坎坷)
于是我想,不整这么复杂,ONNX 模型行不行?查了查官方文档,发现微软出的 ONNX Runtime 框架可以在 Android 上跑 ONNX 模型!
说干就干,转 ONNX 模型应该不难吧?就几行代码的事儿:
from ultralytics import YOLOE
model = YOLOE('yoloe-26l-seg-pf.pt')
model.export(format='onnx')
结果~~又报错了!千问给的解释是:
从你的模型名 YOLOE-26n-seg-pf.pt 可以看出:
- seg:表示支持实例分割;
- pf:表示 Prompt-Free 模式(即无需文本/视觉提示,自动检测所有物体)。 在 Prompt-Free 模式下,YOLOE 可能禁用了文本提示相关的分类头(text prompt head),导致 cls_head 或 bn_head 被设为 None。而当前导出流程中的 fuse() 函数 未正确处理这种“部分 head 缺失”的情况,直接对 None 做了 zip,从而崩溃。
解决方案就是,在导出ONNX模型的时候,禁用文本提示相关的分类头(text prompt head)。
def exportModel(modelname):
model = YOLOE(modelname, task="detect")
# 禁用整个模型的 fuse 行为
original_fuse = getattr(model.model, 'fuse', None)
if original_fuse is not None:
model.model.fuse = lambda: model.model # 返回自身,不 fuse
# 同时禁用 head 的 fuse(双重保险)
head = model.model.model[-1]
if hasattr(head, 'fuse'):
head.fuse = lambda *args, **kwargs: None
model.export(format="onnx", half = True, dynamic=False, simplify=True)
对了,这里的 half=true 参数是为了减小模型大小,导出的是 FP16(float16)精度的模型,跟上次转 TFLite 模型时用的 float16 效果一样,省空间又不咋影响性能——简直是移动端的福音!
3. Android 中整合 ONNX Runtime——实战篇
第一步:添加依赖
先给项目加个 ONNX Runtime 依赖,基本上就像给手机装个插件一样简单:
libs.versions.toml
[versions]
onnxruntimeAndroid = "1.23.2"
onnxruntime-android = { module = "com.microsoft.onnxruntime:onnxruntime-android", version.ref = "onnxruntimeAndroid" }
build.gradle.kts
dependencies {
implementation(libs.onnxruntime.android)
}
第二步:放模型文件
把转好的 ONNX 模型和分类文件丢到 assets 目录里:
对了,这里的 tag_list_chinese.txt 是从官方 GitHub 仓库下载的中文分类文件。
第三步:写识别代码
图像处理的部分上回已经聊过了,这次就不啰嗦了,直接上核心代码。
ONNX Runtime 用的是 OrtSession 来跑模型,咱整一个 OnnxYoloeModel 类:
OnnxYoloeModel.kt
package cn.mengfly.whereareyou.core.detect
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OnnxValue
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.Context
import android.graphics.Bitmap
import android.graphics.RectF
import cn.mengfly.whereareyou.core.loadClasses
import cn.mengfly.whereareyou.core.preProcessImage
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.nio.FloatBuffer
object OnnxYoloeModel : DetectModel {
private const val INPUT = 640 // 输入图片大小
private const val MODEL_PATH = "yoloe-26l-seg-pf.onnx" // 模型路径
private lateinit var session: OrtSession // ONNX 会话
private lateinit var env: OrtEnvironment // ONNX 环境
private lateinit var classes: List<String> // 分类列表
private const val CONF_THRESHOLD = 0.25f // 置信度阈值
/**
* 模型是否初始化完成
*/
override val isInit: Boolean
get() = ::session.isInitialized
override suspend fun init(context: Context) {
// 加载分类文件
classes = loadClasses(context, "tag_list_chinese.txt")
// 加载 onnx 模型
env = OrtEnvironment.getEnvironment();
context.assets.open(MODEL_PATH).use {
val readAllBytes = readAllBytes(it)
session = env.createSession(readAllBytes, OrtSession.SessionOptions())
}
}
override suspend fun detect(bitmap: Bitmap): List<DetectionResult> {
// 预处理图像:缩放到 640x640,归一化到 [0, 1] 范围
// 流程和上篇文章一样,就是把逻辑封装了一下
val resized = bitmap.preProcessImage(INPUT)
// 由于预处理后的图像数据为 HWC(height, width, channel)格式
// 而ONNX模型的输入要求为(batch, channel, height, width)格式
// 所以需要先将HWC转换为CHW(channel, height, width)格式
val chwData = hwcToChw(resized.tensorBuffer.floatArray, INPUT, INPUT, 3)
// 构建输入张量
val inputTensor = OnnxTensor.createTensor(
env,
FloatBuffer.wrap(chwData),
longArrayOf(1, 3, INPUT.toLong(), INPUT.toLong())
)
// 运行模型
val inputs = mapOf<String, OnnxTensor>(
session.inputNames.toList()[0] to inputTensor
)
val result = session.run(inputs, OrtSession.RunOptions())
// 解析输出,过滤掉低置信度的结果
return parseOutput(result[0])
.applyNMS() // 非极大值抑制,去除重叠的框
}
// HWC 转 CHW 格式的工具函数
fun hwcToChw(hwcData: FloatArray, height: Int, width: Int, channels: Int): FloatArray {
val chwData = FloatArray(hwcData.size)
for (h in 0 until height) {
for (w in 0 until width) {
for (c in 0 until channels) {
val hwcIndex = (h * width + w) * channels + c
val chwIndex = c * height * width + h * width + w
chwData[chwIndex] = hwcData[hwcIndex]
}
}
}
return chwData
}
// 读取输入流的工具函数
fun readAllBytes(inputStream: InputStream): ByteArray {
val buffer = ByteArrayOutputStream()
val data = ByteArray(1024)
var bytesRead: Int
while (inputStream.read(data).also { bytesRead = it } != -1) {
buffer.write(data, 0, bytesRead)
}
return buffer.toByteArray()
}
/**
* 解析模型输出
* 注:这里只处理了检测框、分类和置信度,分割输出没处理
*/
private fun parseOutput(output: OnnxValue): List<DetectionResult> {
val tensor = output as OnnxTensor
val detectResult = (tensor.value as Array<*>)[0] as Array<*>
val result = mutableListOf<DetectionResult>()
for (item in detectResult) {
val detectRes = item as FloatArray
// 提取检测信息
val left = detectRes[0]
val top = detectRes[1]
val right = detectRes[2]
val bottom = detectRes[3]
val confidence = detectRes[4]
val classType = detectRes[5].toInt()
// 跳过低置信度的结果
if (confidence < CONF_THRESHOLD) {
continue
}
// 获取类别名称
val classStr: String = if (classType >= classes.size) {
"unknown" // 未知类别
} else {
classes[classType]
}
// 添加到结果列表
result.add(
DetectionResult(
classStr,
confidence,
RectF(left, top, right, bottom)
)
)
}
return result
}
}
整合识别结果
yoloe 虽然能识别的东西变多了,但是误识别的情况也跟着多了——有时候同一个东西,它能给你识别成好几种不同的物体,这就很尴尬了……
细心的伙伴应该注意到了,代码里调用了 applyNMS 方法,这玩意儿是干嘛的?
简单说,就是"非极大值抑制"——过滤掉那些重叠的、置信度低的检测框。但光有这还不够,我还做了个小优化:
如果两个不同类别的检测框几乎完全重合(IOU > 0.99),我就把置信度低的那个标记为"低可信度",显示的时候就能区分开了!
DetectionResult.kt
import android.graphics.RectF
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
data class DetectionResult(
val classType: String, // 类别名称
val confidence: Float, // 置信度
val boundingBox: RectF // 检测框
) {
var isSelected by mutableStateOf(false) // 是否被选中显示
var lowConfidence by mutableStateOf(false) // 是否是低可信度结果
}
/**
* 非极大值抑制(NMS):去除重叠的检测框与标记低可信度结果
*/
fun List<DetectionResult>.applyNMS(iouThreshold: Float = 0.45f): List<DetectionResult> {
// 先按置信度从高到低排序
val sorted = sortedByDescending { it.confidence }
val keep = mutableListOf<DetectionResult>()
for (curr in sorted) {
var keepCurr = true
for (kept in keep) {
if (curr.classType == kept.classType) {
// 相同类别,重叠度太高就干掉
if (calculateIOU(curr.boundingBox, kept.boundingBox) > iouThreshold) {
keepCurr = false
break
}
} else {
// 不同类别,但框几乎重合,标记为低可信度
if (calculateIOU(curr.boundingBox, kept.boundingBox) > 0.99f) {
curr.lowConfidence = true
break
}
}
}
if (keepCurr) keep.add(curr)
}
return keep
}
/**
* 计算交并比(IOU):判断两个框重叠程度
*/
private fun calculateIOU(box1: RectF, box2: RectF): Float {
// 计算重叠区域的坐标
val intersectLeft = maxOf(box1.left, box2.left)
val intersectTop = maxOf(box1.top, box2.top)
val intersectRight = minOf(box1.right, box2.right)
val intersectBottom = minOf(box1.bottom, box2.bottom)
// 计算重叠面积
val intersectArea =
maxOf(0f, intersectRight - intersectLeft) * maxOf(0f, intersectBottom - intersectTop)
// 计算两个框的总面积
val box1Area = (box1.right - box1.left) * (box1.bottom - box1.top)
val box2Area = (box2.right - box2.left) * (box2.bottom - box2.top)
// 交并比 = 重叠面积 / (总面积 - 重叠面积)
return intersectArea / (box1Area + box2Area - intersectArea)
}
4. 识别结果
至于具体怎么调用模型,上一篇已经聊得很详细了,这里就不啰嗦了。不过这次整合的时候,做了两个角度的优化:
-
智能显示:yoloe 识别的东西太多了,要是全显示出来页面得乱成一锅粥。所以改成了手动选择——想看哪个点哪个,清爽又方便!
-
模型切换:我把之前的 TensorFlow 模型也保留了,抽象了一个
DetectModel接口,想切哪个模型就切哪个!
话不多说,直接看效果。
对了,模型和源码咱们已经打包上传到网盘了,具体链接就在我的这篇公众号文章里,需要的小伙伴自己去拿哈!
总结一下:这次成功把 YOLOE 这个号称"识别万物"的模型整到了 Android 里,解决了之前 YOLO 模型只能识别 80 种对象的问题。虽然过程中踩了不少坑,但最终效果还是挺不错的!
各位小伙伴要是有什么问题,欢迎在评论区留言,一起交流学习!