如何将 TensorFlow/Keras 应用到实际项目中?| 七日打卡

925 阅读4分钟

前言

机器学习的热度已经很久了,很多大厂内部也在大量应用,比如UI稿生成代码,或是我们早已体验到的人脸识别等等......作为前端仔的我,在过去的几个月中,遇到了可以应用 TensorFlow 的需求点:以图搜图。在此分享一下。

需求分析

要实现的功能很简单,在有限的原始图片集合中,用户输入一张图片,以图搜图服务返回特征值匹配度高的图片集合。

实现步骤

整个实现步骤分为以下步骤:

  • 收集图片数据
  • 遍历图片数据,计算特征值并保存
  • 获取用户输入图片,识别特征值并与之前的特征值比对
  • 返回数据给前端

收集图片数据

这里如何收集就不展开说了,你可以用任何方式获取,而在我们的业务中,graphql 接口会返回所有图片。但我们的业务图片格式是 svg,因此这里需要做一次转换,将 svg -> jpeg 。这里我用 nodejssharp 包 来完成图片类型的转换。


async function saveICON(icons) {
  const fileType = [];
  async function saveImg(fileName) {
    return sharp(`./svg/${fileName}`)
      .resize(32, 32)
      .flatten({ background: '#ffffff' })
      .webp({ quality: 100 })
      .jpeg()
      .toFile(`***${fileName}.jpeg`);
  }
  const proAll = icons.map(async (item) => {
    try {
      const code = unescape(item.source_code);
      fs.writeFileSync(`***`, code);
      return saveImg(fileName);
    } catch (e) {
      console.error('错误', e);
    }
  });
  return Promise.all(proAll);
}

计算特征值并保存

计算特征值

这里是我们的重点,当拿到所有的 jepg 图片后,就需要计算其特征值了,这里介绍一下 Keras

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时延把你的想法转换为实验结果,是做好研究的关健

因为 TensorFlow 的入门曲线过于陡峭,所以采用它的高级 API Keras 来入门再合适不过。

比如如何用预学习模型获取图片特征值:

# 代码来自:https://keras.io/zh/applications/#vgg16_1
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np

model = VGG16(weights='imagenet', include_top=False)

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

features = model.predict(x)

代码 DEMO 官网都给出了,我们需要做的仅仅是用 Python 遍历图片目录,然后提取特征值并保存就行了。运行一遍方法,提取出来的特征值大概是这个样子,它是一个矩阵:

[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 4.25918460e-01 9.74334311e-03
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 4.47806902e-03 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.80634430e-03
 3.21221575e-02 0.00000000e+00 0.00000000e+00 2.90707089e-02
 6.12684526e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
 6.64397283e-03 ...]

遍历图片列表

翻一翻 Python 文档就知道如何获取文件目录,并遍历,然后将上面的提取特征值代码封装为 vggExtractFeat 函数,就这样实现了遍历+提取。

for i, img_path in enumerate(img_list):
        try:
            norm_feat = model.vggExtractFeat(img_path)
            img_name = os.path.split(img_path)[1]
        except IOError:
            print("Error: 没有找到文件或读取文件失败")

保存和获取特征值

特征值拿到了,该如何保存呢,Python 提供了 h5py 包,他可以将数据存在文件数据中。

h5f = h5py.File('feat.h5', 'w')
# 创建数据库
h5f.create_dataset('feats', data=feats)
h5f.create_dataset('names', data=names)

这样,将特征值数组和文件名称存在 feat.h5 文件中。那如何获取呢?

h5f = h5py.File(index, 'r')
feats = h5f['dataset_1'][:]
imgNames = h5f['dataset_2'][:]

这样我们就完成了计算特征值并保存

比对用户图片并返回图片列表

这里如何实现收集用户输入的搜索图就不展开了,我是起了一个 node 服务,前端通过post请求将图片base64格式发过来。
主要讲如何比对图片特征值。

当nodejs 将图片数据转发给 python 后,同样的调用 vggExtractFeat 方法获取图片特征值。然后进行对比。

前面讲了特征值就是一个矩阵,并且是 1 X N 矩阵。如何对比两个矩阵相似?最简单的方式就是矩阵乘积(市面上有开源工具来优化特征值对比,这里只提供一个最简单的方式)。
举个例子: 假设三个矩阵

A:[.9,0,.9,0]。 B:[.9,0,.9,.9] C: [0,.9,0,.9]

试着分别用A对其他两个不同的矩阵做乘积
[.9,0,.9,0] x [.9,0,.9,.9] = 1.62 [.9,0,.9,0] x [0,.9,0,.9] = 0

事实上,乘积的结果大小,就是匹配度大小。所以B矩阵和A更相似。

通过 python 的 np 包,可以计算两个矩阵的乘积。

np.dot(mix1, mix2),之后将乘积从大到小排序就可以拿到 特征值匹配度较高的图片了。

然后 Python 将图片列表返回给 Nodejs,Nodejs 再将图片返回给前端。至此就完成了以图搜图的全部流程。

参考