COCO-SSD 是一个用于对象检测的模型,它旨在在单张图片中定位和识别多个对象。该模型能够检测 80 种类别的对象。TensorFlow.js 版本的 COCO-SSD 允许在浏览器中直接使用该模型,而无需了解机器学习的底层细节。
在浏览器中使用 TensorFlow.js 的 COCO-SSD 模型的基本步骤:
1. 安装依赖
npm i @tensorflow-models/coco-ssd
npm i @tensorflow/tfjs-backend-cpu
npm i @tensorflow/tfjs-backend-webgl
2. 加载模型
export interface ModelConfig {
base?: ObjectDetectionBaseModel;
modelUrl?: string;
}
cocoSsd.load(config: ModelConfig = {});
base: 控制基础卷积神经网络(CNN)模型的参数,该参数可以是 'mobilenet_v1'、'mobilenet_v2' 或 'lite_mobilenet_v2'。默认情况下,它设置为 'lite_mobilenet_v2'
- lite_mobilenet_v2:这是三种模型中最小的,并且在推理速度上最快。它可能牺牲了一些精度以换取更小的模型大小和更快的运行速度,适合在移动设备或资源有限的环境中使用。
- mobilenet_v2:这种模型在分类准确率上最高。它可能比 'lite_mobilenet_v2' 更大,并且推理速度可能稍慢,但提供了更高的准确性。
- mobilenet_v1:这是 MobileNet 系列的第一个版本,它在模型大小和准确性之间提供了一个平衡。但在 'mobilenet_v2' 和 'lite_mobilenet_v2' 出现后,它可能不是最优的选择,除非有特定的需求或兼容性要求。
modelUrl: 指定模型自定义的URL,该url指向一个已经训练好的模型文件或模型服务的地址。可以将模型放在前端资源中,保证模型能够正常加载。
3. 检测物体
model.detect(
img: tf.Tensor3D | ImageData | HTMLImageElement |
HTMLCanvasElement | HTMLVideoElement, maxNumBoxes: number, minScore: number
)
检测结果:
[{
bbox: [x, y, width, height],
class: "person",
score: 0.8380282521247864
}, {
bbox: [x, y, width, height],
class: "kite",
score: 0.74644153267145157
}]
4. 示例
import '@tensorflow/tfjs-backend-cpu'
import '@tensorflow/tfjs-backend-webgl'
import * as cocoSsd from '@tensorflow-models/coco-ssd'
async function predictionImage(img){
let img = new Image()
img.crossOrigin = 'anonymous'
img.src = img
img.onload = async function () {
const model = await cocoSsd.load({
base: 'lite_mobilenet_v2',
modelUrl: import.meta.env.VITE_BASE_URL + '/model/model.json'
})
let predictions = await model.detect(img)
}
}