使用保存的模型识别图片

151 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

上一篇博客中使用了数据增强的方式来提高模型的训练效果,在模型训练完成后对模型进行了保存,方便后期调用模型实现图片的识别。

本博客将简单的叙述如何使用保存的模型进行图片分类,由于我们使用的CIFAR10进行训练,所以我们提前准备一张猫和一张狗的图片,然后加载模型看是否可以正确识别。 下面直接上代码

# 导入需要的头文件
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input
import numpy as np

# 加载图片,并将图片设置为与数据集同样大小和维度
img = image.load_img('cat2.jpg', target_size=(32, 32, 3))
# 将图片转为array格式
x = image.img_to_array(img)
# 主要用于扩展数组的形状,axis表示几,就在第几维度的地方添加数据,此处为0,表示在第0维度添加数据,完成后数据变为(1, 32, 32, 3)
x = np.expand_dims(x, axis=0)
# 将传入的数据进行处理,类似于数据的归一化操作
x = preprocess_input(x)
# 加载模型,此处说明一点,上一篇博客中代码保存模型时直接用model.save('detectCIFAR10-1.h5'),不使用保存为json和权重文件的方式,这样加载时非常方便
new_model = tf.keras.models.load_model('detectCIFAR10-1.h5')
# 预测类别,CIFAR10中0表示飞机,1表示汽车,2表示鸟,3表示猫,4表示鹿,5表示狗、6表示青蛙,7表示房子、8表示船、9表示卡车
pred = new_model.predict_classes(x)
# 输出预测的类别数字,按照上述的说明可以查看是否预测正确
preds = new_model.predict(x)
print(pred)

之前在最前面的博客中讲到过模型加载的方法,但是随着网络与数据集的不同,加载模型的方法也发生了变化,还有一种比较繁琐的权重和模型分离的加载方法,后面用到时再进行说明。运行上述代码,可以看到自己辛辛苦苦训练的网络终于能派上用场了,可以完成几个类别的图像分类了,但同样会发现网络的准确率非常感人,经常预测错误,不要慌,下一篇博客中我们讲解一种更深度的网络模型,到时候再将模型保存与此处的模型进行比较,看是否得到提升。下一篇博客向着VGG16出发!