浏览器中的机器学习:使用预训练模型

973 阅读4分钟

在上一篇文章《浏览器中的手写数字识别》中,讲到在浏览器中训练出一个卷积神经网络模型,用来识别手写数字。值得注意的是,这个训练过程是在浏览器中完成的,使用的是客户端的资源。

虽然TensorFlow.js的愿景是机器学习无处不在,即使是在手机、嵌入式设备上,只要运行有浏览器,都可以训练人工智能模型,但是考虑到手机、嵌入式设备有限的计算能力(虽然手机性能不断飞跃),复杂的人工智能模型还是交给更为强大的服务器来训练比较合适。况且目前主流的机器学习采用的是python语言,要让广大机器学习工程师从python转向js,估计大家也不会答应。

如果是这样的话,那TensorFlow.js推出还有何意义呢?

这个问题其实和TensorFlow Lite类似,我们可以在服务器端训练,在手机上使用训练出的模型进行推导,通常推导并不需要那么强大的计算能力。

在本文,我们将探索如何在TensorFlow.js中加载预训练的机器学习模型,完成图片分类任务。

在TensorFlow官网,访问 www.tensorflow.org/js/models/ 这个网址,可以看到里面有实时姿态预测模型、目标检测模型、语音识别模型、分类模型等等:

这里我们选择MobileNets模型。MobileNets是一种小型、低延迟、低耗能模型,满足各种资源受限的使用场景,可用于分类、检测、嵌入和分割,功能上类似于其他流行的大型模型(如Inception)。 MobileNets在延迟、大小和准确性之间取得了平衡。

有两种使用MobileNets模型的方案:

  1. 直接调用MobileNets模型的JS封装库
  2. 自己编写代码加载json格式的MobileNets模型

直接调用MobileNets模型的JS封装库

JS封装库直接将MobileNets模型封装为JS对象,我们就像调用普通的JS对象那样,调用对象方法,完成模型加载、推断。

比如访问 github.com/tensorflow/… ,我们可以看到该mobilenet对象提供两个主要的API:

mobilenet.load(
  version?: 1,
  alpha?: 0.25 | .50 | .75 | 1.0
)

参数:

  • 版本:MobileNet版本号。1表示MobileNet V1,2表示使用MobileNet V2。默认值为1。
  • alpha:较小的alpha会降低精度,但会提高性能。默认值为1.0。
model.classify(
  img: tf.Tensor3D | ImageData | HTMLImageElement |
      HTMLCanvasElement | HTMLVideoElement,
  topk?: number
)

参数:

  • img:进行分类的Tensor或image元素。
  • topk:要返回多少个Top概率。默认值为3。

借助于封装的JS库,在浏览器中使用MobileNets就相当简单了:

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"> </script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>
  </head>
  <body>
    <img id="img" src="cat.jpg"></img>
    <script>
      const img = document.getElementById('img');

      // Load the model.
      mobilenet.load().then(model => {
        // Classify the image.
        model.classify(img).then(predictions => {
          console.log('Predictions: ');
          console.log(predictions);
        });
      });
    </script>
  </body>
</html>

注意: 这里的js代码会去google storage 加载MobileNets的JSON格式模型,而由于一些不能说的原因,国内无法访问到,请自行翻墙。

这个示例写的比较简单,从浏览器控制台输出log,显示结果,在chrome浏览器中可以打开开发者工具查看:

加载json格式的MobileNets模型

使用封装好的JS对象确实方便,但使用自己训练的模型时,并没有人为我们提供封装对象。这个时候我们就要考虑自行加载模型,并进行推断。在JS世界,JSON是使用得非常普遍的数据交换格式。TensorFlow.js也采用JSON作为模型格式,也提供了工具进行转换。

本来这里想详细写一下如何加载json格式的MobileNets模型,但由于MobileNets的JS模型托管在Google服务器上,国内无法访问,所以这里先跳过这一步。在下一篇文章中我将说明如何从现有的TensorFlow模型转换为TensorFlow.js模型,并加载之,敬请关注!

以上示例有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。 另外,你也可以在浏览器中直接访问:ilego.club/ai/index.ht… ,直接体验浏览器中的机器学习。

参考文献:

  1. tensorflow官网

你还可以读

  1. 一步步提高手写数字的识别率(1)(2)(3)
  2. TensorFlow.js简介
  3. 浏览器中的手写数字识别

image