使用TensorFlow的4个步骤进行Flutter图像分类

1,027 阅读3分钟

谷歌在TensorFlow家族中提供了一些产品。

  • TensorFlow:核心开放课程库,是开发和训练机器学习模型的基础。
  • TensorFlow.js:类似于TensorFlow,但纯粹专注于JavaScript
  • TensorFlow Lite:顾名思义,它是TensorFlow的一个轻量级版本,用于在移动设备上部署模型。它的功能有限,它只接受预训练的模型注入,并将模型加载到移动设备上。你可以用它来进行图像分类、物体检测和基于自然语言模型的问题/回答。
  • TensorFlow Production:它是TensorFlow在大型生产环境中的一个扩展。

在本文中,我将使用TensorFlow Lite将一个模型部署到Flutter应用程序中。不幸的是,在写这篇文章的时候,还没有一个官方的TensorFlow库用于Flutter,因此我们将使用一个第三方的lib,叫做tflite_flutter


第一步:给自己一个模型

要使用TensorFlow Lite,你必须将一个完整的TensorFlow模型转换为TensorFlow Lite,你不能使用这个库本身来训练一个模型,幸运的是Lite库带有很多预训练的模型,用于图像检测、物体检测、智能回复、姿势估计和分割。另外,你也可以从TensorFlow Hub找到预训练的模型,确保你选择的模型类型是TFLite。

TensorFlow Hub网站给了你很多关于如何注入模型的选择,你可以选择下载或直接导入到Android Studio。

如果你选择下载模型,你将收到的文件将被命名为*"some-image-classification-model.tflite",* 记得解压文件并提取标签,你以后需要model.tflite和标签.txt文件。

unzip some-image-classification-model.tflite

第2步:创建一个Flutter项目

先决条件:IntelliJ或VS Code IDE与Flutter构建环境

创建一个支持Android/iOS/Web的新Flutter项目,或者使用您现有的Flutter项目,如果您有的话。在您的根目录中,创建一个名为 "assets "的文件夹,并在该文件夹中保存您的 "label.txt "和 "model.tflite"。

接下来,到你的项目pubspec.yaml文件中添加以下依赖。

name: tensorflow
description: A new Flutter application.
version: 1.0.0+1
environment:
  sdk: ">=2.12.0 <3.0.0"

dependencies:
  flutter:
    sdk: flutter
  tflite: 1.1.2
  image_picker: 0.7.4

dev_dependencies:
  flutter_test:
    sdk: flutter

flutter:
  uses-material-design: true
  assets:
    - assets/model.tflite
    - assets/label.txt

第3步:编码时间

  • 创建一个Flutter主应用程序
void main() => runApp(MaterialApp(
      home: ImageDetectApp(),
    ));

class ImageDetectApp extends StatefulWidget {
  @override
  _ImageDetectState createState() => _ImageDetectState();
}
  • 创建一个_ImageDetectState类和启动Tflite库
class _ImageDetectState extends State<ImageDetectApp> {
  List? _listResult;
  PickedFile? _imageFile;
  bool _loading = false;

  @override
  void initState() {
    super.initState();
    _loading = true;
    _loadModel();
  }

  void _loadModel() async {
    await Tflite.loadModel(
      model: "assets/model.tflite",
      labels: "assets/label.txt",
    ).then((value) {
      setState(() {
        _loading = false;
      });
    });
  }
  • 在这个类中,创建一个浮动按钮(或任何点击事件),以接收用户的图像选择动作
floatingActionButton: FloatingActionButton(
  onPressed: _imageSelection,
  backgroundColor: Colors.blue,
  child: Icon(Icons.add_photo_alternate_outlined),
)
  • 添加图像选择功能
void _imageSelection() async {
  var imageFile = await ImagePicker().getImage(source: ImageSource.gallery);
  setState(() {
    _loading = true;
    _imageFile = imageFile;
  });
  _imageClasification(imageFile);
}
  • 添加图像分类功能
void _imageClasification(PickedFile image) async {
  var output = await Tflite.runModelOnImage(
    path: image.path,
    numResults: 2,
    threshold: 0.5,
    imageMean: 127.5,
    imageStd: 127.5,
  );
  setState(() {
    _loading = false;
    _listResult = output;
  });
}
  • 最后但并非最不重要:处置Tflite库
@override
void dispose() {
  Tflite.close();
  super.dispose();
}

运行该项目,就可以了


第4步Bouns。训练你自己的模型

有许多方法来训练你自己的模型,在这个例子中,我将使用谷歌colab(https://colab.research.google.com/),你可以运行相同的代码样本显示在这个演示从IDE。

  • 首先,安装软件包作为先决条件
!pip install -q tflite-model-maker

将上述代码添加到代码块中并点击运行,通过点击 "+代码 "符号添加另一个代码块,点击运行来执行以下代码。

import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
from tflite_model_maker.image_classifier import ModelSpec
import matplotlib.pyplot as plt
  • 第二,上传你的数据集

现在是收集图像的时候了!为了有一个准确的结果,你需要每组至少100多张图片,并将它们存储在一个文件夹内。

压缩food_images文件夹(或任何你喜欢使用的文件夹),把这个压缩文件上传到colab,成功上传后,下一步是解压缩(用新的代码块和执行)。

!unzip food_images.zip

  • 将数据加载到设备上的ML应用,并将其分成训练和测试数据(用新的代码块和执行)。

data = DataLoader.from_folder(‘/content/food_images’)
train_data, test_data = data.split(0.9)

from google.colab import drive
drive.mount('/content/drive')
  • 定制TensorFlow模型并评估它
model = image_classifier.create(train_data)
loss, accuracy = model.evaluate(test_data)

  • 输出TensorFlow Lite模型和它的标签
model.export(export_dir=’.’)
model.export(export_dir=’.’, export_format=ExportFormat.LABEL)

下载这个模型和标签,把它们导入到你的Flutter项目中!