前言
机器学习的热度已经很久了,很多大厂内部也在大量应用,比如UI稿生成代码,或是我们早已体验到的人脸识别等等......作为前端仔的我,在过去的几个月中,遇到了可以应用 TensorFlow 的需求点:以图搜图。在此分享一下。
需求分析
要实现的功能很简单,在有限的原始图片集合中,用户输入一张图片,以图搜图服务返回特征值匹配度高的图片集合。
实现步骤
整个实现步骤分为以下步骤:
- 收集图片数据
- 遍历图片数据,计算特征值并保存
- 获取用户输入图片,识别特征值并与之前的特征值比对
- 返回数据给前端
收集图片数据
这里如何收集就不展开说了,你可以用任何方式获取,而在我们的业务中,graphql 接口会返回所有图片。但我们的业务图片格式是 svg,因此这里需要做一次转换,将 svg -> jpeg 。这里我用 nodejs 的 sharp 包 来完成图片类型的转换。
async function saveICON(icons) {
const fileType = [];
async function saveImg(fileName) {
return sharp(`./svg/${fileName}`)
.resize(32, 32)
.flatten({ background: '#ffffff' })
.webp({ quality: 100 })
.jpeg()
.toFile(`***${fileName}.jpeg`);
}
const proAll = icons.map(async (item) => {
try {
const code = unescape(item.source_code);
fs.writeFileSync(`***`, code);
return saveImg(fileName);
} catch (e) {
console.error('错误', e);
}
});
return Promise.all(proAll);
}
计算特征值并保存
计算特征值
这里是我们的重点,当拿到所有的 jepg 图片后,就需要计算其特征值了,这里介绍一下 Keras。
Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时延把你的想法转换为实验结果,是做好研究的关健
因为 TensorFlow 的入门曲线过于陡峭,所以采用它的高级 API Keras 来入门再合适不过。
比如如何用预学习模型获取图片特征值:
# 代码来自:https://keras.io/zh/applications/#vgg16_1
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np
model = VGG16(weights='imagenet', include_top=False)
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
features = model.predict(x)
代码 DEMO 官网都给出了,我们需要做的仅仅是用 Python 遍历图片目录,然后提取特征值并保存就行了。运行一遍方法,提取出来的特征值大概是这个样子,它是一个矩阵:
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 4.25918460e-01 9.74334311e-03
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
4.47806902e-03 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 2.80634430e-03
3.21221575e-02 0.00000000e+00 0.00000000e+00 2.90707089e-02
6.12684526e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
6.64397283e-03 ...]
遍历图片列表
翻一翻 Python 文档就知道如何获取文件目录,并遍历,然后将上面的提取特征值代码封装为 vggExtractFeat 函数,就这样实现了遍历+提取。
for i, img_path in enumerate(img_list):
try:
norm_feat = model.vggExtractFeat(img_path)
img_name = os.path.split(img_path)[1]
except IOError:
print("Error: 没有找到文件或读取文件失败")
保存和获取特征值
特征值拿到了,该如何保存呢,Python 提供了 h5py 包,他可以将数据存在文件数据中。
h5f = h5py.File('feat.h5', 'w')
# 创建数据库
h5f.create_dataset('feats', data=feats)
h5f.create_dataset('names', data=names)
这样,将特征值数组和文件名称存在 feat.h5 文件中。那如何获取呢?
h5f = h5py.File(index, 'r')
feats = h5f['dataset_1'][:]
imgNames = h5f['dataset_2'][:]
这样我们就完成了计算特征值并保存
比对用户图片并返回图片列表
这里如何实现收集用户输入的搜索图就不展开了,我是起了一个 node 服务,前端通过post请求将图片base64格式发过来。
主要讲如何比对图片特征值。
当nodejs 将图片数据转发给 python 后,同样的调用 vggExtractFeat 方法获取图片特征值。然后进行对比。
前面讲了特征值就是一个矩阵,并且是 1 X N 矩阵。如何对比两个矩阵相似?最简单的方式就是矩阵乘积(市面上有开源工具来优化特征值对比,这里只提供一个最简单的方式)。
举个例子:
假设三个矩阵
A:[.9,0,.9,0]。
B:[.9,0,.9,.9]
C: [0,.9,0,.9]
试着分别用A对其他两个不同的矩阵做乘积
[.9,0,.9,0] x [.9,0,.9,.9] = 1.62
[.9,0,.9,0] x [0,.9,0,.9] = 0
事实上,乘积的结果大小,就是匹配度大小。所以B矩阵和A更相似。
通过 python 的 np 包,可以计算两个矩阵的乘积。
np.dot(mix1, mix2),之后将乘积从大到小排序就可以拿到 特征值匹配度较高的图片了。
然后 Python 将图片列表返回给 Nodejs,Nodejs 再将图片返回给前端。至此就完成了以图搜图的全部流程。