Running the yolov8n.onnx model with onnxruntime-web in Vue 3

YOLO (You Only Look Once) is an efficient model widely used for object detection. YOLOv8 is one of the latest versions in the series, offering very fast inference with high accuracy. In this article, we load and run the YOLOv8n (YOLOv8 nano) ONNX model with onnxruntime-web and display the detection results in Vue 3.

1. Prerequisites

Before you begin, make sure you have the following environment and dependencies in place:

  • Node.js and npm installed.
  • A basic Vue 3 project.
  • An exported yolov8n.onnx model (can be obtained here).
  • The onnxruntime-web library installed (version 1.20.1).

Installing onnxruntime-web

In the root directory of your Vue project, install onnxruntime-web:

npm install onnxruntime-web
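
The components below point ort.env.wasm.wasmPaths at a wasm/ folder under the app's public path, so the runtime's .wasm artifacts have to be served from there. A minimal sketch of one way to do that with the third-party vite-plugin-static-copy plugin (the plugin, the glob, and the wasm destination are assumptions; manually copying the files from node_modules/onnxruntime-web/dist into public/wasm works just as well):

// vite.config.ts: a sketch, assuming vite-plugin-static-copy is installed
import { defineConfig } from 'vite';
import vue from '@vitejs/plugin-vue';
import { viteStaticCopy } from 'vite-plugin-static-copy';

export default defineConfig({
  plugins: [
    vue(),
    // Copy the onnxruntime-web runtime artifacts so that ort.env.wasm.wasmPaths
    // (set to `${VITE_PUBLIC_PATH}wasm/` in the components below) can resolve them
    viteStaticCopy({
      targets: [{ src: 'node_modules/onnxruntime-web/dist/*.{wasm,mjs}', dest: 'wasm' }]
    })
  ]
});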

Related links

www.npmjs.com/package/onn…
onnxruntime.ai/docs/tutori…
github.com/Microsoft/o… github.com/microsoft/o…

Real-time video detection

<template>
  <div class="container">
    <div class="content">
      <el-button type="primary" @click="toPage">照片识别</el-button>
      <div class="video-container" v-loading="state.loading">
        <video ref="videoRef" @loadeddata="handleVideoLoaded"></video>
        <div class="wrap">
          <canvas ref="canvasRef"></canvas>
        </div>
      </div>
      <div class="controls">
        <el-space>
          <el-button @click="handlePlay" :disabled="state.loading || state.isRunning">
            {{ state.isRunning ? 'Running' : 'Start' }}
          </el-button>
          <el-button @click="handlePause" :disabled="!state.isRunning">暂停</el-button>
        </el-space>
        <div class="stats" v-if="state.inferCount > 0">
          <p>Inference count: {{ state.inferCount }}</p>
          <p>Average inference time: {{ getAverageInferTime }} ms</p>
        </div>
      </div>
    </div>
  </div>
</template>

<script setup lang="ts">
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue';
// import * as ort from 'onnxruntime-web';
import * as ort from 'onnxruntime-web/webgpu';
import { process_output } from './utils/config';
import { useRouter } from 'vue-router';

const router = useRouter();

interface State {
  loading: boolean;
  isRunning: boolean;
  isVideoLoaded: boolean;
  inferCount: number;
  totalInferTime: number;
  boxes: any[];
}

const videoRef = ref<HTMLVideoElement | null>(null);
const canvasRef = ref<HTMLCanvasElement | null>(null);
let animationFrameId: number | null = null;
let model: ort.InferenceSession | null = null;

const state = reactive<State>({
  loading: true,
  isRunning: false,
  isVideoLoaded: false,
  inferCount: 0,
  totalInferTime: 0,
  boxes: []
});

const getAverageInferTime = computed(() => (state.inferCount ? (state.totalInferTime / state.inferCount).toFixed(2) : '0'));

const loadModel = async () => {
  try {
    const { VITE_PUBLIC_PATH } = import.meta.env;
    ort.env.wasm.wasmPaths = `${VITE_PUBLIC_PATH}wasm/`;

    // model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`);
    model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`, { executionProviders: ['webgpu'] });
    // InferenceSession.create rejects on failure rather than resolving to
    // null, so errors are handled in the catch block below.

    const cameraInitialized = await initCamera();
    if (!cameraInitialized) return;
    state.loading = false;
  } catch (error) {
    console.error('Error loading model:', error);
    state.loading = false;
  }
};

const runModel = async (input: Float32Array): Promise<Float32Array | null> => {
  if (!model) return null;

  try {
    const tensor = new ort.Tensor(input, [1, 3, 640, 640]);
    const outputs = await model.run({ images: tensor });

    return outputs['output0'].data as Float32Array;
  } catch (error) {
    console.error('Error running model:', error);
    return null;
  }
};

const initCamera = async (): Promise<boolean> => {
  if (!videoRef.value) return false;

  try {
    const stream = await navigator.mediaDevices.getUserMedia({ video: true });
    videoRef.value.srcObject = stream;
    return true;
  } catch (error) {
    console.error('Failed to access camera:', error);
    return false;
  }
};

const handleVideoLoaded = () => {
  state.isVideoLoaded = true;
  if (state.isRunning) {
    startInference();
  }
};

const prepareInput = (canvas: HTMLCanvasElement): Float32Array | null => {
  if (!canvas || canvas.width === 0 || canvas.height === 0) return null;

  const tempCanvas = document.createElement('canvas');
  tempCanvas.width = 640;
  tempCanvas.height = 640;
  const ctx = tempCanvas.getContext('2d');

  if (!ctx) return null;

  ctx.drawImage(canvas, 0, 0, 640, 640);
  const imageData = ctx.getImageData(0, 0, 640, 640).data;
  const input = new Float32Array(640 * 640 * 3);

  for (let i = 0, j = 0; i < imageData.length; i += 4, j++) {
    input[j] = imageData[i] / 255; // Red
    input[j + 640 * 640] = imageData[i + 1] / 255; // Green
    input[j + 2 * 640 * 640] = imageData[i + 2] / 255; // Blue
  }

  return input;
};

const drawBoxes = async (canvas: HTMLCanvasElement, boxes: any[]) => {
  const ctx = canvas.getContext('2d');
  if (!ctx) return;

  ctx.clearRect(0, 0, canvas.width, canvas.height);

  ctx.save();
  ctx.strokeStyle = '#00FF00';
  ctx.lineWidth = 3;
  ctx.font = '18px serif';

  boxes.forEach(([x1, y1, x2, y2, label]) => {
    ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);

    ctx.fillStyle = '#00ff00';
    const width = ctx.measureText(label).width;
    ctx.fillRect(x1, y1, width + 10, 25);

    ctx.fillStyle = '#000000';
    ctx.fillText(label, x1, y1 + 18);
  });

  ctx.restore();
};

const startInference = async () => {
  const video = videoRef.value;
  const canvas = canvasRef.value;

  if (!video || !canvas) return;

  const ctx = canvas.getContext('2d');
  if (!ctx) return;

  canvas.width = video.videoWidth;
  canvas.height = video.videoHeight;

  // Reuse the component-scoped animationFrameId declared above so that
  // handlePause can cancel this loop; don't shadow it with a local variable.
  let frameCount = 0; // frame counter used to throttle inference

  const processFrame = async () => {
    if (!state.isRunning) {
      cancelAnimationFrame(animationFrameId);
      return;
    }

    frameCount++;

    // Run inference only every other frame
    if (frameCount % 2 === 0) {
      ctx.drawImage(video, 0, 0);

      const input = prepareInput(canvas);
      if (input) {
        try {
          const startTime = performance.now();
          const output = await runModel(input);
          const endTime = performance.now();

          if (output) {
            state.inferCount++;
            state.totalInferTime += endTime - startTime;
            state.boxes = process_output(output, canvas.width, canvas.height);

            await drawBoxes(canvas, state.boxes);
          }
        } catch (error) {
          console.error('Error during inference:', error);
        }
      }

      frameCount = 0; // reset the frame counter
    }

    animationFrameId = requestAnimationFrame(processFrame);
  };

  // Start the frame-processing loop
  processFrame();
};

const handlePlay = async () => {
  if (state.isRunning) return;

  state.isRunning = true;
  videoRef.value?.play();

  if (state.isVideoLoaded) {
    startInference();
  }
};

const handlePause = () => {
  state.isRunning = false;
  videoRef.value?.pause();
  if (animationFrameId) {
    cancelAnimationFrame(animationFrameId);
    animationFrameId = null;
  }
};

const toPage = () => {
  router.push('/img');
};

onMounted(async () => {
  await loadModel();
});

onUnmounted(() => {
  handlePause();
});
</script>

<style scoped>
.container {
  width: 100%;
  height: 100vh;
  display: flex;
  justify-content: center;
  align-items: center;
}

.content {
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 20px;
}

.video-container {
  position: relative;
  width: 640px;
  height: 480px;
  border: 1px solid #eee;
  border-radius: 8px;
  overflow: hidden;
}

video {
  width: 100%;
  height: 100%;
  object-fit: cover;
}

.wrap {
  position: absolute;
  top: 0;
  left: 0;
  width: 100%;
  height: 100%;
}

canvas {
  width: 100%;
  height: 100%;
  pointer-events: none;
}

.controls {
  padding: 10px;
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 10px;
}

.stats {
  text-align: center;
  font-size: 14px;
  color: #fff;
}

.stats p {
  margin: 5px 0;
}
</style>
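
startInference above throttles work by running inference on every other requestAnimationFrame tick. As a hedged alternative sketch (not part of the original code), many browsers expose HTMLVideoElement.requestVideoFrameCallback, which fires once per presented video frame, so inference naturally tracks the camera's frame rate. The helper and its parameters below are hypothetical:

// Sketch: drive inference from presented video frames instead of rAF ticks.
// requestVideoFrameCallback is not available in all browsers, so
// feature-detect before relying on it.
const startVideoFrameLoop = (
  video: HTMLVideoElement,
  isRunning: () => boolean,
  onFrame: () => Promise<void>
) => {
  const loop = async () => {
    if (!isRunning()) return;               // stop when paused
    await onFrame();                        // one inference + draw pass
    video.requestVideoFrameCallback(loop);  // schedule for the next frame
  };
  video.requestVideoFrameCallback(loop);
};

// Hypothetical usage inside handlePlay:
// startVideoFrameLoop(videoRef.value!, () => state.isRunning, processOneFrame);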

Photo upload detection

<template>
  <div class="container">
    <div class="content">
      <el-button type="primary" @click="toPage">实时识别</el-button>
      <div class="image-container" v-loading="state.loading">
        <div class="wrap">
          <canvas ref="canvasRef"></canvas>
        </div>
      </div>
      <div class="controls">
        <el-space>
          <el-upload class="upload-demo" action="" :auto-upload="false" :show-file-list="false" accept="image/*" @change="handleFileChange">
            <el-button type="primary">点击上传</el-button>
          </el-upload>
          <el-button @click="handleClear" :disabled="!state.hasImage">清除</el-button>
        </el-space>

        <div class="stats" v-if="state.inferCount > 0">
          <p>Inference count: {{ state.inferCount }}</p>
          <p>Inference time: {{ state.lastInferTime.toFixed(2) }} ms</p>
        </div>
      </div>
    </div>
  </div>
</template>

<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue';
import * as ort from 'onnxruntime-web/webgpu';
import { process_output } from './utils/config';
import { useRouter } from 'vue-router';

const router = useRouter();

interface State {
  loading: boolean;
  hasImage: boolean;
  inferCount: number;
  lastInferTime: number;
  boxes: any[];
}

const canvasRef = ref<HTMLCanvasElement | null>(null);
let model: ort.InferenceSession | null = null;

const state = reactive<State>({
  loading: true,
  hasImage: false,
  inferCount: 0,
  lastInferTime: 0,
  boxes: []
});

const loadModel = async () => {
  try {
    const { VITE_PUBLIC_PATH } = import.meta.env;

    ort.env.wasm.wasmPaths = `${VITE_PUBLIC_PATH}wasm/`;

    model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`, {
      executionProviders: ['webgpu']
    });
    // As above, InferenceSession.create rejects on failure rather than
    // resolving to null, so errors are handled in the catch block.
    state.loading = false;
  } catch (error) {
    console.error('Error loading model:', error);
    state.loading = false;
  }
};

const runModel = async (input: Float32Array): Promise<Float32Array | null> => {
  if (!model) return null;

  try {
    const tensor = new ort.Tensor(input, [1, 3, 640, 640]);
    const outputs = await model.run({ images: tensor });

    return outputs['output0'].data as Float32Array;
  } catch (error) {
    console.error('Error running model:', error);
    return null;
  }
};

const prepareInput = (canvas: HTMLCanvasElement): Float32Array | null => {
  if (!canvas || canvas.width === 0 || canvas.height === 0) return null;

  const tempCanvas = document.createElement('canvas');
  tempCanvas.width = 640;
  tempCanvas.height = 640;
  const ctx = tempCanvas.getContext('2d');

  if (!ctx) return null;

  ctx.drawImage(canvas, 0, 0, 640, 640);
  const imageData = ctx.getImageData(0, 0, 640, 640).data;
  const input = new Float32Array(640 * 640 * 3);

  for (let i = 0, j = 0; i < imageData.length; i += 4, j++) {
    input[j] = imageData[i] / 255;
    input[j + 640 * 640] = imageData[i + 1] / 255;
    input[j + 2 * 640 * 640] = imageData[i + 2] / 255;
  }

  return input;
};

const drawBoxes = (canvas: HTMLCanvasElement, boxes: any[]) => {
  const ctx = canvas.getContext('2d');
  if (!ctx) return;

  ctx.strokeStyle = '#00FF00';
  ctx.lineWidth = 3;
  ctx.font = '18px serif';

  boxes.forEach(([x1, y1, x2, y2, label]) => {
    ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);

    ctx.fillStyle = '#00ff00';
    const width = ctx.measureText(label).width;
    ctx.fillRect(x1, y1, width + 10, 25);

    ctx.fillStyle = '#000000';
    ctx.fillText(label, x1, y1 + 18);
  });
};

const detect = async (canvas: HTMLCanvasElement) => {
  const input = prepareInput(canvas);
  if (!input) return;

  const startTime = performance.now();
  const output = await runModel(input);

  const endTime = performance.now();

  if (output) {
    state.inferCount++;
    state.lastInferTime = endTime - startTime;
    state.boxes = process_output(output, canvas.width, canvas.height);

    drawBoxes(canvas, state.boxes);
  }
};

const handleFileChange = async (file: any) => {
  if (!canvasRef.value) return;

  const canvas = canvasRef.value;
  const ctx = canvas.getContext('2d');
  if (!ctx) return;

  state.loading = true;

  const img = new Image();
  img.onload = async () => {
    canvas.width = img.width;
    canvas.height = img.height;
    ctx.drawImage(img, 0, 0);
    state.hasImage = true;
    // Automatically run detection after the image is loaded
    await detect(canvas);
    URL.revokeObjectURL(img.src); // release the temporary object URL
    state.loading = false; // clear loading only after detection finishes
  };
  img.src = URL.createObjectURL(file.raw);
};

const handleClear = () => {
  if (!canvasRef.value) return;

  const ctx = canvasRef.value.getContext('2d');
  if (!ctx) return;

  ctx.clearRect(0, 0, canvasRef.value.width, canvasRef.value.height);
  state.hasImage = false;
  state.boxes = [];
};

const toPage = () => {
  router.push('/model');
};

onMounted(async () => {
  await loadModel();
});
</script>

<style scoped>
.container {
  width: 100%;
  height: 100vh;
  display: flex;
  justify-content: center;
  align-items: center;
  position: relative;
}

.topage {
  position: absolute;
  top: 0;
  left: 0;
}

.content {
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 20px;
}

.image-container {
  position: relative;
  width: 640px;
  min-height: 480px;
  border: 1px solid #eee;
  border-radius: 8px;
  overflow: hidden;
  display: flex;
  justify-content: center;
  align-items: center;
}

.wrap {
  width: 100%;
  height: 100%;
}

canvas {
  max-width: 100%;
  max-height: 100%;
  pointer-events: none;
}

.controls {
  padding: 10px;
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 10px;
}

.stats {
  text-align: center;
  font-size: 14px;
  color: #fff;
}

.stats p {
  margin: 5px 0;
}

:deep(.el-upload-dragger) {
  width: 360px;
  height: 180px;
}
</style>

config (utils/config.ts)

// The 80 COCO class names, in the order YOLOv8 outputs them
const yolo_classes = [
  'person',
  'bicycle',
  'car',
  'motorcycle',
  'airplane',
  'bus',
  'train',
  'truck',
  'boat',
  'traffic light',
  'fire hydrant',
  'stop sign',
  'parking meter',
  'bench',
  'bird',
  'cat',
  'dog',
  'horse',
  'sheep',
  'cow',
  'elephant',
  'bear',
  'zebra',
  'giraffe',
  'backpack',
  'umbrella',
  'handbag',
  'tie',
  'suitcase',
  'frisbee',
  'skis',
  'snowboard',
  'sports ball',
  'kite',
  'baseball bat',
  'baseball glove',
  'skateboard',
  'surfboard',
  'tennis racket',
  'bottle',
  'wine glass',
  'cup',
  'fork',
  'knife',
  'spoon',
  'bowl',
  'banana',
  'apple',
  'sandwich',
  'orange',
  'broccoli',
  'carrot',
  'hot dog',
  'pizza',
  'donut',
  'cake',
  'chair',
  'couch',
  'potted plant',
  'bed',
  'dining table',
  'toilet',
  'tv',
  'laptop',
  'mouse',
  'remote',
  'keyboard',
  'cell phone',
  'microwave',
  'oven',
  'toaster',
  'sink',
  'refrigerator',
  'book',
  'clock',
  'vase',
  'scissors',
  'teddy bear',
  'hair drier',
  'toothbrush'
];

export const iou = (box1: any, box2: any) => {
  return intersection(box1, box2) / union(box1, box2);
};

export const intersection = (box1: any, box2: any) => {
  const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
  const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
  const x1 = Math.max(box1_x1, box2_x1);
  const y1 = Math.max(box1_y1, box2_y1);
  const x2 = Math.min(box1_x2, box2_x2);
  const y2 = Math.min(box1_y2, box2_y2);
  // Clamp to 0 so non-overlapping boxes don't yield a spurious positive area
  return Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
};

export const union = (box1: any, box2: any) => {
  const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
  const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
  const box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1);
  const box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1);
  return box1_area + box2_area - intersection(box1, box2);
};
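
A quick worked example of the helpers above: two 100×100 boxes offset by 50 px overlap in a 50×50 region, so the IoU is 2500 / (10000 + 10000 - 2500) ≈ 0.143.

// Worked example for iou/intersection/union (illustrative values)
const a = [0, 0, 100, 100];
const b = [50, 50, 150, 150];
console.log(intersection(a, b)); // 2500
console.log(union(a, b));        // 17500
console.log(iou(a, b));          // ≈ 0.1428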

export const process_output = (output: any, img_width: number, img_height: number): any[] => {
  let boxes = [] as any;
  // The flattened YOLOv8 output has shape [1, 84, 8400]: for each of the 8400
  // candidate boxes there are 4 coordinates (xc, yc, w, h) followed by 80
  // class probabilities, stored channel-by-channel.
  for (let index = 0; index < 8400; index++) {
    // Pick the class with the highest probability for this candidate
    const [class_id, prob] = [...Array(yolo_classes.length).keys()]
      .map(col => [col, output[8400 * (col + 4) + index]])
      .reduce((accum, item) => (item[1] > accum[1] ? item : accum), [0, 0]);
    if (prob < 0.5) {
      continue;
    }
    const label = yolo_classes[class_id];
    const xc = output[index];
    const yc = output[8400 + index];
    const w = output[2 * 8400 + index];
    const h = output[3 * 8400 + index];
    // Map center/size in the 640×640 model space back to image coordinates
    const x1 = ((xc - w / 2) / 640) * img_width;
    const y1 = ((yc - h / 2) / 640) * img_height;
    const x2 = ((xc + w / 2) / 640) * img_width;
    const y2 = ((yc + h / 2) / 640) * img_height;
    boxes.push([x1, y1, x2, y2, label, prob]);
  }

  // Greedy non-maximum suppression: keep the highest-confidence box, drop
  // everything that overlaps it with IoU >= 0.7, then repeat.
  boxes = boxes.sort((box1: any, box2: any) => box2[5] - box1[5]);
  const result = [];
  while (boxes.length > 0) {
    result.push(boxes[0]);
    boxes = boxes.filter((box: any) => iou(boxes[0], box) < 0.7);
  }
  return result;
};
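
As a quick sanity check (a hypothetical test, not from the original article), you can feed process_output a synthetic 84 × 8400 buffer containing a single confident detection and confirm that exactly one box comes back:

// Build a fake output tensor: one candidate (index 0) that is a confident class-0 hit
const fake = new Float32Array(84 * 8400);
fake[0] = 320;        // xc
fake[8400] = 320;     // yc
fake[2 * 8400] = 100; // w
fake[3 * 8400] = 200; // h
fake[4 * 8400] = 0.9; // probability of class 0 ('person')

const detections = process_output(fake, 640, 640);
console.log(detections); // ≈ [[270, 220, 370, 420, 'person', 0.9]]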

// Alternative input-preparation helper (the components above define their own
// prepareInput); kept here for reference.
export const prepare_input = (img: HTMLCanvasElement) => {
  if (!img || img.width === 0 || img.height === 0) {
    console.error('Invalid canvas size');
    return null;
  }

  const canvas = document.createElement('canvas') as HTMLCanvasElement;
  canvas.width = 640;
  canvas.height = 640;
  const context = canvas.getContext('2d');

  if (!context) return null;

  context.drawImage(img, 0, 0, 640, 640);
  const data = context.getImageData(0, 0, 640, 640).data;
  const red = [],
    green = [],
    blue = [];
  for (let index = 0; index < data.length; index += 4) {
    red.push(data[index] / 255);
    green.push(data[index + 1] / 255);
    blue.push(data[index + 2] / 255);
  }
  return [...red, ...green, ...blue];
};

References

rps.regulusai.top/

基于onnx的web端yolov8模型部署与推理 (YOLOv8 model deployment and inference on the web based on ONNX)

Viewing model information

netron.app/
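
Besides inspecting the graph in netron.app, you can read the input and output names used above ('images' and 'output0') straight off the session; inputNames and outputNames are part of the onnxruntime-web InferenceSession API. The model path here is hypothetical:

import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model/yolov8n.onnx');
console.log(session.inputNames);  // e.g. ['images']
console.log(session.outputNames); // e.g. ['output0']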

// Older versions: using the GPU (e.g. v1.8.0)
import * as ort from 'onnxruntime-web';
await ort.InferenceSession.create(
  'model path',
  { executionProviders: ['webgpu'] }
);


// Newer versions: using the GPU (e.g. v1.20.1)
import * as ort from 'onnxruntime-web/webgpu';
await ort.InferenceSession.create(
  'model path',
  { executionProviders: ['webgpu'] }
);
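
If WebGPU might be unavailable on a visitor's browser, executionProviders accepts a fallback order and onnxruntime-web tries each provider in turn. A minimal sketch (the wasm fallback is an assumption, not something this article's demo uses):

// Try the WebGPU execution provider first, fall back to wasm (CPU)
import * as ort from 'onnxruntime-web/webgpu';

const session = await ort.InferenceSession.create('model/yolov8n.onnx', {
  executionProviders: ['webgpu', 'wasm']
});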

This was cobbled together from bits and pieces and isn't polished; bear with it.

Repository address