本文已参与「新人创作礼」活动,一起开启掘金创作之路。
为了使得项目相对有意义,本博客使用简单的QT界面进行实现,主要使用tensorflow hub完成图片的识别,tensorflow hub是Google提供的预训练模型的可选集合,其主要是机器学习模组打包函数库,帮开发者将tensorflow的训练模型打包成模组,方便使用者可以直接使用预训练权重完成某些任务。TensorFlow Hub是一个专注于可复用机器学习(也就是迁移学习transfer learning)的开源仓库与开源库。tfhub.dev仓库提供了许多预训练模型,例如word embeddings,图像分类模型等等。
Keras是一个TensorFlow的高阶API,它可以通过组成Keras Layer对象来建立深度学习模型。tensorflow_hub库提供了hub.KerasLayer类,我们可以提供一个SavedModel的URL或者文件地址来初始化hub.KerasLayer,然后它可以给我们提供SavedModel中的计算结果(包括预训练的权重)。
Hub包括了许多种不同用途的模块,比如文本分类、语句编码、图像分类、特征提取、使用GAN生成图像和视频分类,目前Google和DeepMind都致力于发布工作。下面我将通过代码来实际的应用tensorflowhub。
代码主要分为两个部分,ImageNet的1000类别我已经封装为python中的列表形式,供识别时调用,同时有需要的也可以在此直接取用。代码如下
第一部分,QT界面和分类识别代码,识别返回的类别代码通过列表的映射得到其中的真实类别,并将其绘制到图像中进行显示。
import sys
from PyQt5.QtWidgets import (QWidget, QHBoxLayout, QLabel, QApplication)
from PyQt5.QtWidgets import QFileDialog, QPushButton
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import PIL.Image as Image
import ImageNetClass
class Example (QWidget):
def __init__(self):
super ().__init__()
self.initUI ()
# 初始化GUI,通过自选图片方式将图片进行分类显示
def initUI(self):
# lbl = QLabel(self)
lbl = QPushButton(self)
lbl.setText("单击选择图片")
lbl.setCheckable(True)
lbl.move(10, 10)
lbl.clicked[bool].connect(self.setColor)
hbox = QHBoxLayout(self)
hbox.addWidget(lbl)
self.setLayout (hbox)
self.move (300, 200)
self.setWindowTitle ('选择图片')
self.resize(400, 200)
self.show ()
# 检测图片中的物体种类
def detect(self, path):
classifier_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"
IMAGE_SHAPE = (224, 224)
classifier = tf.keras.Sequential(
[hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE + (3,))]
)
grace_hopper = Image.open(path[0]).resize(IMAGE_SHAPE)
grace_hopper = np.array(grace_hopper) / 255.0
result = classifier.predict(grace_hopper[np.newaxis, ...])
predict_class = np.argmax(result[0], axis=-1)
print(predict_class)
# print(ImageNetClass.listClass[predict_class])
plt.imshow(grace_hopper)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.title(ImageNetClass.listClass[predict_class])
plt.show()
return ImageNetClass.listClass[predict_class]
# 按钮的槽函数,用于打开图片获取图片路径
def setColor(self, pressed):
source = self.sender()
path1 = QFileDialog.getOpenFileName(self, "打开图片", "", "*.jpg *.png *.jpeg *.bmp")
str = self.detect(path1)
print(str)
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = Example ()
sys.exit (app.exec_())
第二部分,ImageNet对应的label,这里没放上,有1000行,数据量有点大,需要的私信我
下面是运行的部分截图
界面非常简洁,只包含一个按钮,单击按钮即可选择图片进行识别
识别猫的结果,主要分为两种猫的识别,下面分别是识别的结果
下面是对两种狗的识别