TensorFlow 之 计算机视觉入门--用 CNN 来改进 fashion-mnist 多分类任务

352 阅读3分钟

前言

上篇文章介绍了如何使用TensorFlow来搭建神经网络模型,对fashion-mnist数据集进行多分类的预测。

上篇文章的思路,是把(28, 28)的图片展开成 (768, 1)的一维数据,然后进行全连接层搭建模型,

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation = tf.nn.relu),
    tf.keras.layers.Dense(10, activation = tf.nn.softmax)
])

从上篇文章的模型代码中,我们可以看到,只使用了全连接层。最后的准确率大概是 86%~87%之间。

那么有没有方法再提高一些准确率呢,答案是有的,今天我们来使用卷积神经网络来提升准确率。

前情回顾

和上次一样,我们这次使用的数据集依然是fashion-mnist。是一个图片多分类的任务。包含衣服、鞋子等各种物品。

image.png

每一个训练数据都对应一个类别。 其Label的取值范围是 0 - 9 ,一共10类。每一类对应的实际含义如下表:

LabelDescription
0T-shirt/top
1Trouser
2Pullover
3Dress
4Coat
5Sandal
6Shirt
7Sneaker
8Bag
9Ankle boot

模型搭建

导入数据
import tensorflow as tf
minst = tf.keras.datasets.fashion_mnist.load_data()
数据集拆分

留一部分当训练集,一部分当测试集

(training_images, training_labels), (test_images, test_labels) = minst
归一化

归一化的作用是在梯度下降的时候,对每个特征都能有效的稳定的求导。

training_images = training_images / 255
test_images = test_images/255
维度调整

我们把数据集变成了 28281的维度,1表示1个颜色通道,由于这里的图像是灰度图,所以只有1个通道。 在彩色图片的情况下,就可能是R,G,B三个维度了。

training_images = training_images.reshape(60000, 28, 28, 1)
test_images = test_images.reshape(10000, 28, 28, 1)
建立模型

首先建立模型。建立后的模型,里面的输入和输出,我们待会儿通过一条命令来查看。

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D(2, 2))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(2, 2))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
编译模型

然后是编译模型,这里 loss使用的是sparse_categorical_crossentropy,一般可用在多分类任务里。 optimizer使用的是Adam。

model.compile(
    optimizer = tf.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=['accuracy']
)
模型概况
model.summary()

我们可以看到每一层的类别,以及的输入和输出的维度。最后一层softmax进行10分类。

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_5 (Conv2D)           (None, 26, 26, 64)        640       
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 13, 13, 64)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 11, 11, 64)        36928     
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 flatten_5 (Flatten)         (None, 1600)              0         
                                                                 
 dense_8 (Dense)             (None, 128)               204928    
                                                                 
 dense_9 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 243,786
Trainable params: 243,786
Non-trainable params: 0
模型训练
model.fit(training_images, training_labels, epochs = 50)

下面是训练过程,我么可以发现准确率逐渐提高到了99.45%。

Epoch 1/50
1875/1875 [==============================] - 7s 3ms/step - loss: 0.4372 - accuracy: 0.8412
Epoch 2/50
1875/1875 [==============================] - 6s 3ms/step - loss: 0.2946 - accuracy: 0.8912
Epoch 3/50
...
...
Epoch 48/50
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0173 - accuracy: 0.9945
Epoch 49/50
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0158 - accuracy: 0.9952
Epoch 50/50
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0170 - accuracy: 0.9945
结果预测
model.evaluate(test_images, test_labels)

结果如下:

1,我们发现最后的准确率达到了90.6%,比只使用全连接层,多了4个点的准确率。

2,训练集的准确率达到了99%以上了。但是测试集只有90.6%,这说明发生了过拟合。

313/313 [==============================] - 1s 2ms/step - loss: 0.9121 - accuracy: 0.9060

后记

有在机器学习和深度学习的小伙伴,可以留言和点赞,我们一起交流。