使用 TensorFlow.js 实现图像分类

390 阅读3分钟

在当今科技迅猛发展的时代,人工智能(AI)已成为各行各业不可或缺的一部分。Iphone集成AI,豆包,文心一言,GPT等的兴起,AI化已成为趋势,模型训练、NLP、前端与AI的结合等可能是我们未来要掌握的一种技能。在这篇文章中,我们将深入探讨如何使用 TensorFlow.js 构建一个简单的图像分类器,利用 MobileNet 模型进行实时分类,并结合 React 框架来实现前端交互。让我们一起走进前端和AI的结合的世界!

项目概述

我们的项目目标是实现一个图像分类应用,用户可以上传图片,系统通过预训练的 MobileNet 模型识别出图片中的物体,并展示分类结果。通过这个项目,读者不仅可以掌握图像分类的基本概念,还能学习如何在前端环境中使用 TensorFlow.js

环境准备

首先,我们需要确保我们的开发环境中安装了 Node.js 和 npm。接下来,使用以下命令创建一个新的 React 应用:

npx create-react-app image-classifier
cd image-classifier

然后,我们需要安装 TensorFlow.js:

npm install @tensorflow/tfjs

此外,我们需要准备一个类别标签文件 imagenet_class_index.json,它包含了 ImageNet 数据集中每个类别的 ID 和名称。

imagenet_class_index.json

代码实现

以下是项目的核心代码,主要分为几个部分:

  1. 加载模型:使用 TensorFlow.js 加载 MobileNet 模型。
  2. 处理用户上传的图像:读取并预处理用户上传的图片。
  3. 进行分类预测:利用模型对处理后的图片进行分类,并获取预测结果。
import React, { useState, useRef, useEffect } from 'react';
import * as tf from '@tensorflow/tfjs';
import labelsData from './imagenet_class_index.json';
import './App.css';
​
function App() {
    const [imageURL, setImageURL] = useState(null);
    const [predictions, setPredictions] = useState([]);
    const imageRef = useRef();
    const modelRef = useRef();
​
    useEffect(() => {
        const loadModel = async () => {
            console.log('Loading model...');
            modelRef.current = await tf.loadGraphModel(
                'https://www.kaggle.com/models/google/mobilenet-v2/TfJs/100-224-classification/3',
                { fromTFHub: true }
            );
            console.log('Model loaded successfully');
        };
        loadModel();
    }, []);
​
    const handleUpload = (e) => {
        const file = e.target.files[0];
        const reader = new FileReader();
        reader.onload = () => setImageURL(reader.result);
        reader.readAsDataURL(file);
    };
​
    const classifyImage = async () => {
        if (!modelRef.current) {
            console.error('Model not loaded yet');
            return;
        }
​
        const image = tf.browser.fromPixels(imageRef.current);
        const resized = tf.image.resizeBilinear(image, [224, 224]).expandDims(0).div(255);
        const prediction = await modelRef.current.predict(resized).array();
​
        const topPrediction = prediction[0]
            .map((p, i) => ({ class: i, probability: p }))
            .sort((a, b) => b.probability - a.probability)[0];
​
        const className = labelsData[topPrediction.class - 1] ? labelsData[topPrediction.class - 1][1] : "未知类别";
​
        // TODO: 置信度可能存在bug
        const probability = Math.min((topPrediction.probability * 100).toFixed(2), 100);
​
        setPredictions([{ class: topPrediction.class, className, probability }]);
        image.dispose();
    };
​
    return (
        <div className="App">
            <h1>Image Classifier with TensorFlow.js</h1>
            <input type="file" accept="image/*" onChange={handleUpload} />
            {imageURL && (
                <div>
                    <img src={imageURL} alt="Upload Preview" ref={imageRef} onLoad={classifyImage} />
                    <h3>Predictions:</h3>
                    <ul>
                        {predictions.map((p, index) => (
                            <li key={index}>Class: {p.class}, Name: {p.className}, Probability: {p.probability}%</li>
                        ))}
                    </ul>
                </div>
            )}
        </div>
    );
}
​
export default App;

代码解析

  1. 模型加载:使用 tf.loadGraphModel 方法异步加载预训练的 MobileNet 模型,确保在组件加载时只调用一次。
  2. 图片上传和预处理:通过 FileReader 读取用户上传的图片,并将其转换为可用于 TensorFlow.js 的格式。
  3. 预测过程:对图像进行预处理(调整大小、归一化等),然后调用模型进行预测,获取最可能的分类结果并更新 UI。

好了,我们的demo就做完了 是不是很简单呢!

1730265850520.jpg

实际应用与思考

本项目展示了如何在前端环境中结合 TensorFlow.js 和 React 进行图像分类。随着 AI 技术的不断进步,图像识别将会在多个领域产生深远影响,例如医疗诊断、自动驾驶、安防监控等。未来,我们可以进一步扩展项目功能,支持多图分类、实时视频分类等高级特性。

总结

通过本项目,我们不仅掌握了 TensorFlow.js 的基本用法,还加深了对图像分类技术的理解。在 AI 迅速发展的背景下,掌握相关技术将有助于我们在未来的职业发展中把握机遇。


后续我们将自己训练一些小的模型,然后前端和我们的模型深度结合,构建属于我们自己的AI。