YOLO (You Only Look Once) is a family of efficient models widely used for object detection. YOLOv8 is one of the newest versions in the series, offering very fast inference with solid accuracy. In this article we will use onnxruntime-web to load and run the YOLOv8n (YOLOv8 nano) ONNX model and display the detection results in Vue 3.
1. Preparation
Before you start, make sure you have the following environment and dependencies:
- Node.js and npm installed.
- A basic Vue 3 project.
- An exported yolov8n.onnx model.
- The onnxruntime-web library (version 1.20.1).
Install onnxruntime-web
In the root directory of your Vue project, install onnxruntime-web:
npm install onnxruntime-web
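onnxruntime-web fetches its WebAssembly binaries at runtime. The components below point ort.env.wasm.wasmPaths at a wasm/ folder under the public path, so those files have to be served from there. One way to do that with Vite is a static-copy plugin; the following is only a minimal sketch, assuming vite-plugin-static-copy is installed:
// vite.config.ts (sketch; assumes vite-plugin-static-copy is installed)
import { defineConfig } from 'vite';
import vue from '@vitejs/plugin-vue';
import { viteStaticCopy } from 'vite-plugin-static-copy';

export default defineConfig({
  plugins: [
    vue(),
    viteStaticCopy({
      // Copy the onnxruntime-web wasm artifacts so ort.env.wasm.wasmPaths can point at /wasm/.
      targets: [{ src: 'node_modules/onnxruntime-web/dist/*.wasm', dest: 'wasm' }]
    })
  ]
});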
Related links
www.npmjs.com/package/onn…
onnxruntime.ai/docs/tutori…
github.com/Microsoft/o…
github.com/microsoft/o…
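Before wiring the model into a component, a quick standalone check helps confirm that the runtime and model load correctly. This is a minimal sketch, assuming yolov8n.onnx is served at /model/yolov8n.onnx:
import * as ort from 'onnxruntime-web/webgpu';

async function smokeTest() {
  const session = await ort.InferenceSession.create('/model/yolov8n.onnx', {
    executionProviders: ['webgpu']
  });
  // yolov8n takes a 640×640 RGB image as a [1, 3, 640, 640] float32 tensor named 'images'.
  const dummy = new ort.Tensor(new Float32Array(1 * 3 * 640 * 640), [1, 3, 640, 640]);
  const outputs = await session.run({ images: dummy });
  console.log(outputs.output0.dims); // expect [1, 84, 8400]
}

smokeTest();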
Real-time video detection
<template>
<div class="container">
<div class="content">
<el-button type="primary" @click="toPage">照片识别</el-button>
<div class="video-container" v-loading="state.loading">
<video ref="videoRef" @loadeddata="handleVideoLoaded"></video>
<div class="wrap">
<canvas ref="canvasRef"></canvas>
</div>
</div>
<div class="controls">
<el-space>
<el-button @click="handlePlay" :disabled="state.loading || state.isRunning">
{{ state.isRunning ? '运行中' : '开始' }}
</el-button>
<el-button @click="handlePause" :disabled="!state.isRunning">暂停</el-button>
</el-space>
<div class="stats" v-if="state.inferCount > 0">
<p>Inference count: {{ state.inferCount }}</p>
<p>Average inference time: {{ getAverageInferTime }} ms</p>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue';
// import * as ort from 'onnxruntime-web'; // default entry: CPU/WASM only
import * as ort from 'onnxruntime-web/webgpu'; // WebGPU-enabled entry (v1.20.1)
import { process_output } from './utils/config';
import { useRouter } from 'vue-router';
const router = useRouter();
interface State {
loading: boolean;
isRunning: boolean;
isVideoLoaded: boolean;
inferCount: number;
totalInferTime: number;
boxes: any[];
}
const videoRef = ref<HTMLVideoElement | null>(null);
const canvasRef = ref<HTMLCanvasElement | null>(null);
let animationFrameId: number | null = null;
let model: ort.InferenceSession | null = null;
const state = reactive<State>({
loading: true,
isRunning: false,
isVideoLoaded: false,
inferCount: 0,
totalInferTime: 0,
boxes: []
});
const getAverageInferTime = computed(() => (state.inferCount ? (state.totalInferTime / state.inferCount).toFixed(2) : '0'));
const loadModel = async () => {
try {
const { VITE_PUBLIC_PATH } = import.meta.env;
ort.env.wasm.wasmPaths = `${VITE_PUBLIC_PATH}wasm/`;
// model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`); // WASM-only variant
model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`, { executionProviders: ['webgpu'] });
// InferenceSession.create either resolves with a session or throws, so no retry is needed here.
const cameraInitialized = await initCamera();
if (!cameraInitialized) return;
state.loading = false;
} catch (error) {
console.error('Error loading model:', error);
state.loading = false;
}
};
const runModel = async (input: Float32Array): Promise<Float32Array | null> => {
if (!model) return null;
try {
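// The model's input is named 'images' and takes a [1, 3, 640, 640] float32 tensor (the element type is inferred from the Float32Array).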
const tensor = new ort.Tensor(input, [1, 3, 640, 640]);
const outputs = await model.run({ images: tensor });
return outputs['output0'].data as Float32Array;
} catch (error) {
console.error('Error running model:', error);
return null;
}
};
const initCamera = async (): Promise<boolean> => {
if (!videoRef.value) return false;
try {
const stream = await navigator.mediaDevices.getUserMedia({ video: true });
videoRef.value.srcObject = stream;
return true;
} catch (error) {
console.error('Failed to access camera:', error);
return false;
}
};
const handleVideoLoaded = () => {
state.isVideoLoaded = true;
if (state.isRunning) {
startInference();
}
};
const prepareInput = (canvas: HTMLCanvasElement): Float32Array | null => {
if (!canvas || canvas.width === 0 || canvas.height === 0) return null;
const tempCanvas = document.createElement('canvas');
tempCanvas.width = 640;
tempCanvas.height = 640;
const ctx = tempCanvas.getContext('2d');
if (!ctx) return null;
ctx.drawImage(canvas, 0, 0, 640, 640);
const imageData = ctx.getImageData(0, 0, 640, 640).data;
const input = new Float32Array(640 * 640 * 3);
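// Convert interleaved RGBA from the canvas into planar RGB (CHW order), normalized to [0, 1]; alpha is dropped.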
for (let i = 0, j = 0; i < imageData.length; i += 4, j++) {
input[j] = imageData[i] / 255; // Red
input[j + 640 * 640] = imageData[i + 1] / 255; // Green
input[j + 2 * 640 * 640] = imageData[i + 2] / 255; // Blue
}
return input;
};
const drawBoxes = async (canvas: HTMLCanvasElement, boxes: any[]) => {
const ctx = canvas.getContext('2d');
if (!ctx) return;
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.save();
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 3;
ctx.font = '18px serif';
boxes.forEach(([x1, y1, x2, y2, label]) => {
ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
ctx.fillStyle = '#00ff00';
const width = ctx.measureText(label).width;
ctx.fillRect(x1, y1, width + 10, 25);
ctx.fillStyle = '#000000';
ctx.fillText(label, x1, y1 + 18);
});
ctx.restore();
};
const startInference = async () => {
const video = videoRef.value;
const canvas = canvasRef.value;
if (!video || !canvas) return;
const ctx = canvas.getContext('2d');
if (!ctx) return;
canvas.width = video.videoWidth;
canvas.height = video.videoHeight;
let frameCount = 0; // frame counter (reuse the outer animationFrameId so handlePause can cancel this loop)
const processFrame = async () => {
if (!state.isRunning) {
if (animationFrameId !== null) cancelAnimationFrame(animationFrameId);
return;
}
frameCount++;
// process every other frame
if (frameCount % 2 === 0) {
ctx.drawImage(video, 0, 0);
const input = prepareInput(canvas);
if (input) {
try {
const startTime = performance.now();
const output = await runModel(input);
const endTime = performance.now();
if (output) {
state.inferCount++;
state.totalInferTime += endTime - startTime;
state.boxes = process_output(output, canvas.width, canvas.height);
await drawBoxes(canvas, state.boxes);
}
} catch (error) {
console.error('Error during inference:', error);
}
}
frameCount = 0; // reset the frame counter
}
animationFrameId = requestAnimationFrame(processFrame);
};
// start processing frames
processFrame();
};
const handlePlay = async () => {
if (state.isRunning) return;
state.isRunning = true;
videoRef.value?.play();
if (state.isVideoLoaded) {
startInference();
}
};
const handlePause = () => {
state.isRunning = false;
videoRef.value?.pause();
if (animationFrameId) {
cancelAnimationFrame(animationFrameId);
animationFrameId = null;
}
};
const toPage = () => {
router.push('/img');
};
onMounted(async () => {
await loadModel();
});
onUnmounted(() => {
handlePause();
});
</script>
<style scoped>
.container {
width: 100%;
height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.content {
display: flex;
flex-direction: column;
align-items: center;
gap: 20px;
}
.video-container {
position: relative;
width: 640px;
height: 480px;
border: 1px solid #eee;
border-radius: 8px;
overflow: hidden;
}
video {
width: 100%;
height: 100%;
object-fit: cover;
}
.wrap {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
}
canvas {
width: 100%;
height: 100%;
pointer-events: none;
}
.controls {
padding: 10px;
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
}
.stats {
text-align: center;
font-size: 14px;
color: #fff;
}
.stats p {
margin: 5px 0;
}
</style>
Photo upload detection
<template>
<div class="container">
<div class="content">
<el-button type="primary" @click="toPage">实时识别</el-button>
<div class="image-container" v-loading="state.loading">
<div class="wrap">
<canvas ref="canvasRef"></canvas>
</div>
</div>
<div class="controls">
<el-space>
<el-upload class="upload-demo" action="" :auto-upload="false" :show-file-list="false" accept="image/*" @change="handleFileChange">
<el-button type="primary">点击上传</el-button>
</el-upload>
<el-button @click="handleClear" :disabled="!state.hasImage">清除</el-button>
</el-space>
<div class="stats" v-if="state.inferCount > 0">
<p>Inference count: {{ state.inferCount }}</p>
<p>Inference time: {{ state.lastInferTime.toFixed(2) }} ms</p>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue';
import * as ort from 'onnxruntime-web/webgpu';
import { process_output } from './utils/config';
import { useRouter } from 'vue-router';
const router = useRouter();
interface State {
loading: boolean;
hasImage: boolean;
inferCount: number;
lastInferTime: number;
boxes: any[];
}
const canvasRef = ref<HTMLCanvasElement | null>(null);
let model: ort.InferenceSession | null = null;
const state = reactive<State>({
loading: true,
hasImage: false,
inferCount: 0,
lastInferTime: 0,
boxes: []
});
const loadModel = async () => {
try {
const { VITE_PUBLIC_PATH } = import.meta.env;
ort.env.wasm.wasmPaths = `${VITE_PUBLIC_PATH}wasm/`;
model = await ort.InferenceSession.create(`${VITE_PUBLIC_PATH}model/yolov8n.onnx`, {
executionProviders: ['webgpu']
});
// InferenceSession.create either resolves with a session or throws, so no retry is needed here.
state.loading = false;
} catch (error) {
console.error('Error loading model:', error);
state.loading = false;
}
};
const runModel = async (input: Float32Array): Promise<Float32Array | null> => {
if (!model) return null;
try {
const tensor = new ort.Tensor(input, [1, 3, 640, 640]);
const outputs = await model.run({ images: tensor });
return outputs['output0'].data as Float32Array;
} catch (error) {
console.error('Error running model:', error);
return null;
}
};
const prepareInput = (canvas: HTMLCanvasElement): Float32Array | null => {
if (!canvas || canvas.width === 0 || canvas.height === 0) return null;
const tempCanvas = document.createElement('canvas');
tempCanvas.width = 640;
tempCanvas.height = 640;
const ctx = tempCanvas.getContext('2d');
if (!ctx) return null;
ctx.drawImage(canvas, 0, 0, 640, 640);
const imageData = ctx.getImageData(0, 0, 640, 640).data;
const input = new Float32Array(640 * 640 * 3);
for (let i = 0, j = 0; i < imageData.length; i += 4, j++) {
input[j] = imageData[i] / 255;
input[j + 640 * 640] = imageData[i + 1] / 255;
input[j + 2 * 640 * 640] = imageData[i + 2] / 255;
}
return input;
};
const drawBoxes = (canvas: HTMLCanvasElement, boxes: any[]) => {
const ctx = canvas.getContext('2d');
if (!ctx) return;
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 3;
ctx.font = '18px serif';
boxes.forEach(([x1, y1, x2, y2, label]) => {
ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
ctx.fillStyle = '#00ff00';
const width = ctx.measureText(label).width;
ctx.fillRect(x1, y1, width + 10, 25);
ctx.fillStyle = '#000000';
ctx.fillText(label, x1, y1 + 18);
});
};
const detect = async (canvas: HTMLCanvasElement) => {
const input = prepareInput(canvas);
if (!input) return;
const startTime = performance.now();
const output = await runModel(input);
const endTime = performance.now();
if (output) {
state.inferCount++;
state.lastInferTime = endTime - startTime;
state.boxes = process_output(output, canvas.width, canvas.height);
drawBoxes(canvas, state.boxes);
}
};
const handleFileChange = async (file: any) => {
if (!canvasRef.value) return;
state.loading = true;
const canvas = canvasRef.value;
const ctx = canvas.getContext('2d');
if (!ctx) {
state.loading = false;
return;
}
const img = new Image();
img.onload = async () => {
canvas.width = img.width;
canvas.height = img.height;
ctx.drawImage(img, 0, 0);
state.hasImage = true;
// Run detection automatically once the image has loaded.
await detect(canvas);
URL.revokeObjectURL(img.src);
// Clear the loading state only after detection finishes, not synchronously.
state.loading = false;
};
img.src = URL.createObjectURL(file.raw);
};
const handleClear = () => {
if (!canvasRef.value) return;
const ctx = canvasRef.value.getContext('2d');
if (!ctx) return;
ctx.clearRect(0, 0, canvasRef.value.width, canvasRef.value.height);
state.hasImage = false;
state.boxes = [];
};
const toPage = () => {
router.push('/model');
};
onMounted(async () => {
await loadModel();
});
</script>
<style scoped>
.container {
width: 100%;
height: 100vh;
display: flex;
justify-content: center;
align-items: center;
position: relative;
}
.topage {
position: absolute;
top: 0;
left: 0;
}
.content {
display: flex;
flex-direction: column;
align-items: center;
gap: 20px;
}
.image-container {
position: relative;
width: 640px;
min-height: 480px;
border: 1px solid #eee;
border-radius: 8px;
overflow: hidden;
display: flex;
justify-content: center;
align-items: center;
}
.wrap {
width: 100%;
height: 100%;
}
canvas {
max-width: 100%;
max-height: 100%;
pointer-events: none;
}
.controls {
padding: 10px;
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
}
.stats {
text-align: center;
font-size: 14px;
color: #fff;
}
.stats p {
margin: 5px 0;
}
:deep(.el-upload-dragger) {
width: 360px;
height: 180px;
}
</style>
config (utils/config.ts)
const yolo_classes = [
// The 80 COCO class names, in the order the model outputs them.
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'
];
export const iou = (box1: any, box2: any) => {
return intersection(box1, box2) / union(box1, box2);
};
export const intersection = (box1: any, box2: any) => {
const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
const x1 = Math.max(box1_x1, box2_x1);
const y1 = Math.max(box1_y1, box2_y1);
const x2 = Math.min(box1_x2, box2_x2);
const y2 = Math.min(box1_y2, box2_y2);
// Clamp to 0 so disjoint boxes produce zero intersection instead of a negative area.
return Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
};
export const union = (box1: any, box2: any) => {
const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
const box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1);
const box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1);
return box1_area + box2_area - intersection(box1, box2);
};
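// output0 has shape [1, 84, 8400]: for each of the 8400 candidate boxes there are 4 box values
// (xc, yc, w, h in 640x640 coordinates) followed by 80 class scores, stored channel-first,
// so channel c of box i lives at output[c * 8400 + i].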
export const process_output = (output: any, img_width: number, img_height: number): any[] => {
let boxes = [] as any;
for (let index = 0; index < 8400; index++) {
const [class_id, prob] = [...Array(yolo_classes.length).keys()]
.map(col => [col, output[8400 * (col + 4) + index]])
.reduce((accum, item) => (item[1] > accum[1] ? item : accum), [0, 0]);
if (prob < 0.5) {
continue;
}
const label = yolo_classes[class_id];
const xc = output[index];
const yc = output[8400 + index];
const w = output[2 * 8400 + index];
const h = output[3 * 8400 + index];
const x1 = ((xc - w / 2) / 640) * img_width;
const y1 = ((yc - h / 2) / 640) * img_height;
const x2 = ((xc + w / 2) / 640) * img_width;
const y2 = ((yc + h / 2) / 640) * img_height;
boxes.push([x1, y1, x2, y2, label, prob]);
}
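// Greedy non-maximum suppression: keep the highest-probability box, drop boxes that overlap it (IoU >= 0.7), repeat.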
boxes = boxes.sort((box1: any, box2: any) => box2[5] - box1[5]);
const result = [];
while (boxes.length > 0) {
result.push(boxes[0]);
boxes = boxes.filter((box: any) => iou(boxes[0], box) < 0.7);
}
return result;
};
// Standalone equivalent of the components' prepareInput; returns a plain number array in CHW order.
export const prepare_input = (img: HTMLCanvasElement) => {
if (!img || img.width === 0 || img.height === 0) {
console.error('Invalid canvas size');
return null;
}
const canvas = document.createElement('canvas') as HTMLCanvasElement;
canvas.width = 640;
canvas.height = 640;
const context = canvas.getContext('2d');
if (!context) return null;
context.drawImage(img, 0, 0, 640, 640);
const data = context.getImageData(0, 0, 640, 640).data;
const red = [],
green = [],
blue = [];
for (let index = 0; index < data.length; index += 4) {
red.push(data[index] / 255);
green.push(data[index + 1] / 255);
blue.push(data[index + 2] / 255);
}
return [...red, ...green, ...blue];
};
Reference links
View model information
Pitfalls
// Old versions: using the GPU (e.g. v1.8.0)
import * as ort from 'onnxruntime-web';
await ort.InferenceSession.create(
'path/to/model',
{ executionProviders: ['webgpu'] }
);
// New versions: using the GPU (e.g. v1.20.1); the WebGPU build now has its own entry point
import * as ort from 'onnxruntime-web/webgpu';
await ort.InferenceSession.create(
'path/to/model',
{ executionProviders: ['webgpu'] }
);
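If the page should also work in browsers without WebGPU, one option (not from the original setup) is to feature-detect and fall back to the WASM execution provider. A sketch, assuming the wasm artifacts are served under /wasm/:
import * as ort from 'onnxruntime-web/webgpu';

// Fall back to the WASM execution provider when WebGPU is unavailable.
ort.env.wasm.wasmPaths = '/wasm/';
const executionProviders = 'gpu' in navigator ? ['webgpu', 'wasm'] : ['wasm'];
const session = await ort.InferenceSession.create('/model/yolov8n.onnx', { executionProviders });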
This post was cobbled together from various sources and isn't polished; please bear with it.