本文已参与「新人创作礼」活动,一起开启掘金创作之路。
CIFAR-10一共包含10个类别:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。
下面直接上代码
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers
import matplotlib.pyplot as plt
# 定义网络参数
EPOCH = 30
BATCH_SIZE = 128
# 输出带进度条的日志信息
VERBOSE = 1
# 优化器
OPTIMIZER = optimizers.RMSprop()
# 划分比例
VALIDATION_SPLIT = 0.2
# 输入图片大小
img_w, img_h = 32, 32
# 输入图片的通道数,区别于手写数据集
img_channels = 3
input_shape = (img_w, img_h, img_channels)
# 类别数
num_class = 10
# 加载数据集
(train_x, train_y), (test_x, test_y) = datasets.cifar10.load_data()
# 查看数据集图片与标签内容
# plt.imshow(train_x[10])
# plt.show()
# print(train_y[10])
train_y = tf.squeeze(train_y)
train_y = tf.one_hot(train_y, depth=10)
test_y = tf.squeeze(test_y)
test_y = tf.one_hot(test_y, depth=10)
# 定义网络结构
def NetStruct(in_shape, in_class):
# 创建容器
model = models.Sequential()
# 在容器中加入网络层
# 卷积层,32个filter,大小为3x3
model.add(layers.Convolution2D(32, (3, 3), activation='relu', input_shape=in_shape))
# 最大池化
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
# dropout随机失活一部分神经元
model.add(layers.Dropout(0.3))
# 将输入层的数据压成一维的数据,一般用再卷积层和全连接层之间(因为全连接层只能接收一维数据,而卷积层可以处理二维数据,就是全连接层处理的是向量,而卷积层处理的是矩阵)
model.add(layers.Flatten())
# 全连接层,512个单元,使用relu激活
model.add(layers.Dense(512, activation='relu'))
# 随机失活一部分神经元
model.add(layers.Dropout(0.5))
# 带有十个类别的输出的softmax层
model.add(layers.Dense(in_class, activation='softmax'))
return model
# 保存tensorboard
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./blog07/')]
model = NetStruct(in_shape=input_shape, in_class=num_class)
model.compile(loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=['accuracy'])
# 喂数据
model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, validation_split=VALIDATION_SPLIT, verbose=VERBOSE, callbacks=callbacks)
score = model.evaluate(test_x, test_y, batch_size=BATCH_SIZE, verbose=VERBOSE)
print("Test score:", score[0])
print("Test accuracy:", score[1])
这份代码很简单,网络结构也非常简单,如下图所示
正是因为相对简单,所以也带来了一个问题,如下面的loss与accuracy图所示
在训练了大约12个epoch之后,正确率不再增加,而且损失函数也不再下降,所以整个网络的性能也肯定好不到哪里去。相比于手写数据集的识别,CIFAR10的识别为什么突然变得这么差了呢?其实,这跟数据集有非常大的关系,手写数据集相对较为简单,图像也均为灰度图,而CIFAR10全是3通道RGB图像,而且图像的背景也较为复杂,在特征提取的时候难免导致效果不好,这也是随着学习的深入我们需要解决的问题,在下一篇博客中我将在网络性能的改善方面提供一些方法,供大家参考。
最后,附上悲惨的训练结果