基于tensorflow.js,实现图像物体检测

827 阅读2分钟

COCO-SSD 是一个用于对象检测的模型,它旨在在单张图片中定位和识别多个对象。该模型能够检测 80 种类别的对象。TensorFlow.js 版本的 COCO-SSD 允许在浏览器中直接使用该模型,而无需了解机器学习的底层细节。

在浏览器中使用 TensorFlow.js 的 COCO-SSD 模型的基本步骤:

1. 安装依赖

npm i @tensorflow-models/coco-ssd
npm i @tensorflow/tfjs-backend-cpu
npm i @tensorflow/tfjs-backend-webgl

2. 加载模型

export interface ModelConfig {
  base?: ObjectDetectionBaseModel;
  modelUrl?: string;
}

cocoSsd.load(config: ModelConfig = {});

base: 控制基础卷积神经网络(CNN)模型的参数,该参数可以是 'mobilenet_v1'、'mobilenet_v2' 或 'lite_mobilenet_v2'。默认情况下,它设置为 'lite_mobilenet_v2'

  • lite_mobilenet_v2:这是三种模型中最小的,并且在推理速度上最快。它可能牺牲了一些精度以换取更小的模型大小和更快的运行速度,适合在移动设备或资源有限的环境中使用。
  • mobilenet_v2:这种模型在分类准确率上最高。它可能比 'lite_mobilenet_v2' 更大,并且推理速度可能稍慢,但提供了更高的准确性。
  • mobilenet_v1:这是 MobileNet 系列的第一个版本,它在模型大小和准确性之间提供了一个平衡。但在 'mobilenet_v2' 和 'lite_mobilenet_v2' 出现后,它可能不是最优的选择,除非有特定的需求或兼容性要求。

modelUrl: 指定模型自定义的URL,该url指向一个已经训练好的模型文件或模型服务的地址。可以将模型放在前端资源中,保证模型能够正常加载。

3. 检测物体

model.detect(
  img: tf.Tensor3D | ImageData | HTMLImageElement |
      HTMLCanvasElement | HTMLVideoElement, maxNumBoxes: number, minScore: number
)

检测结果:

[{
  bbox: [x, y, width, height],
  class: "person",
  score: 0.8380282521247864
}, {
  bbox: [x, y, width, height],
  class: "kite",
  score: 0.74644153267145157
}]

4. 示例

import '@tensorflow/tfjs-backend-cpu'
import '@tensorflow/tfjs-backend-webgl'
import * as cocoSsd from '@tensorflow-models/coco-ssd'

async function predictionImage(img){
    let img = new Image()
    img.crossOrigin = 'anonymous'
    img.src = img
    img.onload = async function () {
      const model = await cocoSsd.load({
        base: 'lite_mobilenet_v2',
        modelUrl: import.meta.env.VITE_BASE_URL + '/model/model.json'
      })
      let predictions = await model.detect(img)
    }
}

Screenity video - May 20, 2024.gif