使用PyQt5与Tensorflow Hub实现图片的识别

424 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

为了使得项目相对有意义,本博客使用简单的QT界面进行实现,主要使用tensorflow hub完成图片的识别,tensorflow hub是Google提供的预训练模型的可选集合,其主要是机器学习模组打包函数库,帮开发者将tensorflow的训练模型打包成模组,方便使用者可以直接使用预训练权重完成某些任务。TensorFlow Hub是一个专注于可复用机器学习(也就是迁移学习transfer learning)的开源仓库与开源库。tfhub.dev仓库提供了许多预训练模型,例如word embeddings,图像分类模型等等。

Keras是一个TensorFlow的高阶API,它可以通过组成Keras Layer对象来建立深度学习模型。tensorflow_hub库提供了hub.KerasLayer类,我们可以提供一个SavedModel的URL或者文件地址来初始化hub.KerasLayer,然后它可以给我们提供SavedModel中的计算结果(包括预训练的权重)。

Hub包括了许多种不同用途的模块,比如文本分类、语句编码、图像分类、特征提取、使用GAN生成图像和视频分类,目前Google和DeepMind都致力于发布工作。下面我将通过代码来实际的应用tensorflowhub。

代码主要分为两个部分,ImageNet的1000类别我已经封装为python中的列表形式,供识别时调用,同时有需要的也可以在此直接取用。代码如下

第一部分,QT界面和分类识别代码,识别返回的类别代码通过列表的映射得到其中的真实类别,并将其绘制到图像中进行显示。

import sys
from PyQt5.QtWidgets import (QWidget, QHBoxLayout, QLabel, QApplication)
from PyQt5.QtWidgets import QFileDialog, QPushButton
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import PIL.Image as Image
import ImageNetClass


class Example (QWidget):
    def __init__(self):
        super ().__init__()
        self.initUI ()

    # 初始化GUI,通过自选图片方式将图片进行分类显示
    def initUI(self):
        # lbl = QLabel(self)
        lbl = QPushButton(self)
        lbl.setText("单击选择图片")
        lbl.setCheckable(True)
        lbl.move(10, 10)
        lbl.clicked[bool].connect(self.setColor)
        hbox = QHBoxLayout(self)
        hbox.addWidget(lbl)
        self.setLayout (hbox)
        self.move (300, 200)
        self.setWindowTitle ('选择图片')
        self.resize(400, 200)
        self.show ()

    # 检测图片中的物体种类
    def detect(self, path):
        classifier_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"
        IMAGE_SHAPE = (224, 224)
        classifier = tf.keras.Sequential(
            [hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE + (3,))]
        )
        grace_hopper = Image.open(path[0]).resize(IMAGE_SHAPE)

        grace_hopper = np.array(grace_hopper) / 255.0
        result = classifier.predict(grace_hopper[np.newaxis, ...])
        predict_class = np.argmax(result[0], axis=-1)
        print(predict_class)
        # print(ImageNetClass.listClass[predict_class])
        plt.imshow(grace_hopper)
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        plt.title(ImageNetClass.listClass[predict_class])
        plt.show()
        return ImageNetClass.listClass[predict_class]

    # 按钮的槽函数,用于打开图片获取图片路径
    def setColor(self, pressed):
        source = self.sender()
        path1 = QFileDialog.getOpenFileName(self, "打开图片", "", "*.jpg *.png *.jpeg *.bmp")
        str = self.detect(path1)
        print(str)


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = Example ()
    sys.exit (app.exec_())

第二部分,ImageNet对应的label,这里没放上,有1000行,数据量有点大,需要的私信我

下面是运行的部分截图 界面非常简洁,只包含一个按钮,单击按钮即可选择图片进行识别 1.PNG

识别猫的结果,主要分为两种猫的识别,下面分别是识别的结果 2.PNG

4.PNG

下面是对两种狗的识别

3.PNG

5.PNG