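Preface
The previous article gave a brief introduction to TensorFlow's "Hello World", in which we used TensorFlow to make predictions with a linear model.
Today we will tackle one of the simplest image classification tasks.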
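Dataset
The dataset we use is Fashion-MNIST. Its GitHub address is: github.com/zalandorese…
The dataset contains 70,000 grayscale images of clothing items, each 28×28 pixels: 60,000 for training and 10,000 for testing.
Each image belongs to exactly one category. Labels range from 0 to 9, 10 classes in total; the meaning of each class is given in the table below: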
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
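The dataset itself only provides the integer labels, so a program has to keep the descriptions from the table alongside it. A minimal sketch of that lookup in Python; CLASS_NAMES and label_to_name are hypothetical names chosen for this example, not part of the TensorFlow API:

```python
# Label -> description mapping, copied from the table above.
# CLASS_NAMES is a hypothetical helper, not something the dataset provides.
CLASS_NAMES = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]


def label_to_name(label: int) -> str:
    """Translate an integer label (0-9) into its description."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in 0-9, got {label}")
    return CLASS_NAMES[label]


print(label_to_name(9))  # Ankle boot
```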
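A first look at the dataset
Now let's start coding and use TensorFlow for this image classification task.
First, load the dataset:
import tensorflow as tf
fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()
Unpack it into the training images, training labels, test images, and test labels:
(training_images, training_labels), (test_images, test_labels) = fashion_mnist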
Next, let's see what the first image in the training set looks like:
import matplotlib.pyplot as plt
training_images[0]
We can see that it is a multi-dimensional array of pixel values:
array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 13, 73, 0, 0, 1, 4, 0, 0, 0, 0, 1,
1, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
0, 36, 136, 127, 62, 54, 0, 0, 0, 1, 3, 4, 0,
0, 3],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,
0, 102, 204, 176, 134, 144, 123, 23, 0, 0, 0, 0, 12,
10, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 155, 236, 207, 178, 107, 156, 161, 109, 64, 23, 77, 130,
72, 15],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
69, 207, 223, 218, 216, 216, 163, 127, 121, 122, 146, 141, 88,
172, 66],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
200, 232, 232, 233, 229, 223, 223, 215, 213, 164, 127, 123, 196,
229, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
183, 225, 216, 223, 228, 235, 227, 224, 222, 224, 221, 223, 245,
173, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
193, 228, 218, 213, 198, 180, 212, 210, 211, 213, 223, 220, 243,
202, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 12,
219, 220, 212, 218, 192, 169, 227, 208, 218, 224, 212, 226, 197,
209, 52],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 99,
244, 222, 220, 218, 203, 198, 221, 215, 213, 222, 220, 245, 119,
167, 56],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 55,
236, 228, 230, 228, 240, 232, 213, 218, 223, 234, 217, 217, 209,
92, 0],
[ 0, 0, 1, 4, 6, 7, 2, 0, 0, 0, 0, 0, 237,
226, 217, 223, 222, 219, 222, 221, 216, 223, 229, 215, 218, 255,
77, 0],
[ 0, 3, 0, 0, 0, 0, 0, 0, 0, 62, 145, 204, 228,
207, 213, 221, 218, 208, 211, 218, 224, 223, 219, 215, 224, 244,
159, 0],
[ 0, 0, 0, 0, 18, 44, 82, 107, 189, 228, 220, 222, 217,
226, 200, 205, 211, 230, 224, 234, 176, 188, 250, 248, 233, 238,
215, 0],
[ 0, 57, 187, 208, 224, 221, 224, 208, 204, 214, 208, 209, 200,
159, 245, 193, 206, 223, 255, 255, 221, 234, 221, 211, 220, 232,
246, 0],
[ 3, 202, 228, 224, 221, 211, 211, 214, 205, 205, 205, 220, 240,
80, 150, 255, 229, 221, 188, 154, 191, 210, 204, 209, 222, 228,
225, 0],
[ 98, 233, 198, 210, 222, 229, 229, 234, 249, 220, 194, 215, 217,
241, 65, 73, 106, 117, 168, 219, 221, 215, 217, 223, 223, 224,
229, 29],
[ 75, 204, 212, 204, 193, 205, 211, 225, 216, 185, 197, 206, 198,
213, 240, 195, 227, 245, 239, 223, 218, 212, 209, 222, 220, 221,
230, 67],
[ 48, 203, 183, 194, 213, 197, 185, 190, 194, 192, 202, 214, 219,
221, 220, 236, 225, 216, 199, 206, 186, 181, 177, 172, 181, 205,
206, 115],
[ 0, 122, 219, 193, 179, 171, 183, 196, 204, 210, 213, 207, 211,
210, 200, 196, 194, 191, 195, 191, 198, 192, 176, 156, 167, 177,
210, 92],
[ 0, 0, 74, 189, 212, 191, 175, 172, 175, 181, 185, 188, 189,
188, 193, 198, 204, 209, 210, 210, 211, 188, 188, 194, 192, 216,
170, 0],
[ 2, 0, 0, 0, 66, 200, 222, 237, 239, 242, 246, 243, 244,
221, 220, 193, 191, 179, 182, 182, 181, 176, 166, 168, 99, 58,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 40, 61, 44, 72, 41, 35,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]], dtype=uint8)
Next, let's check its exact dimensions:
training_images[0].shape
The result shows a 28×28 two-dimensional array:
(28, 28)
Now let's print its label and display the array as an image:
print(training_labels[0])
plt.imshow(training_images[0])
The result is shown below; label 9 corresponds to Ankle boot.
9
(Figure: the first training image, training_images[0], displayed with plt.imshow)
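plt.imshow simply turns each number in the array into a color. As a rough, matplotlib-free illustration that the 28×28 array really is a picture, the hypothetical to_ascii helper below maps each intensity to a character, with denser characters for brighter pixels:

```python
# Hypothetical helper (not part of matplotlib or TensorFlow): print a grayscale
# image whose values are 0-255 as text, one character per pixel.
SHADES = " .:-=+*#%@"  # 10 levels: background (0) -> brightest (255)


def to_ascii(image):
    lines = []
    for row in image:
        # int() avoids uint8 overflow when given NumPy pixels; scale 0..255 to 0..9.
        chars = [SHADES[int(pixel) * (len(SHADES) - 1) // 255] for pixel in row]
        lines.append("".join(chars))
    return "\n".join(lines)


print(to_ascii([[0, 255], [128, 0]]))
```

With the data loaded above, print(to_ascii(training_images[0])) should show a coarse outline of the ankle boot. The helper assumes raw 0-255 values, so use it before the normalization step that follows.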
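Data preprocessing
Before classifying the images, we first normalize them, scaling every pixel value from the 0-255 range down to 0-1: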
training_images = training_images / 255
test_images = test_images/255
Then let's check whether the content of the first image has changed:
training_images[0]
Every number in the first image is now between 0 and 1, which is exactly what normalization should achieve:
array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.00392157, 0. , 0. , 0.05098039, 0.28627451, 0. , 0. , 0.00392157, 0.01568627, 0. , 0. , 0. , 0. , 0.00392157, 0.00392157, 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01176471, 0. , 0.14117647, 0.53333333, 0.49803922, 0.24313725, 0.21176471, 0. , 0. , 0. , 0.00392157, 0.01176471, 0.01568627, 0. , 0. , 0.01176471],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.02352941, 0. , 0.4 , 0.8 , 0.69019608, 0.5254902 , 0.56470588, 0.48235294, 0.09019608, 0. , 0. , 0. , 0. , 0.04705882, 0.03921569, 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.60784314, 0.9254902 , 0.81176471, 0.69803922, 0.41960784, 0.61176471, 0.63137255, 0.42745098, 0.25098039, 0.09019608, 0.30196078, 0.50980392, 0.28235294, 0.05882353],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.00392157, 0. , 0.27058824, 0.81176471, 0.8745098 , 0.85490196, 0.84705882, 0.84705882, 0.63921569, 0.49803922, 0.4745098 , 0.47843137, 0.57254902, 0.55294118, 0.34509804, 0.6745098 , 0.25882353],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.00392157, 0.00392157, 0.00392157, 0. , 0.78431373, 0.90980392, 0.90980392, 0.91372549, 0.89803922, 0.8745098 , 0.8745098 , 0.84313725, 0.83529412, 0.64313725, 0.49803922, 0.48235294, 0.76862745, 0.89803922, 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.71764706, 0.88235294, 0.84705882, 0.8745098 , 0.89411765, 0.92156863, 0.89019608, 0.87843137, 0.87058824, 0.87843137, 0.86666667, 0.8745098 , 0.96078431, 0.67843137, 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.75686275, 0.89411765, 0.85490196, 0.83529412, 0.77647059, 0.70588235, 0.83137255, 0.82352941, 0.82745098, 0.83529412, 0.8745098 , 0.8627451 , 0.95294118, 0.79215686, 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.00392157, 0.01176471, 0. , 0.04705882, 0.85882353, 0.8627451 , 0.83137255, 0.85490196, 0.75294118, 0.6627451 , 0.89019608, 0.81568627, 0.85490196, 0.87843137, 0.83137255, 0.88627451, 0.77254902, 0.81960784, 0.20392157],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.02352941, 0. , 0.38823529, 0.95686275, 0.87058824, 0.8627451 , 0.85490196, 0.79607843, 0.77647059, 0.86666667, 0.84313725, 0.83529412, 0.87058824, 0.8627451 , 0.96078431, 0.46666667, 0.65490196, 0.21960784],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01568627, 0. , 0. , 0.21568627, 0.9254902 , 0.89411765, 0.90196078, 0.89411765, 0.94117647, 0.90980392, 0.83529412, 0.85490196, 0.8745098 , 0.91764706, 0.85098039, 0.85098039, 0.81960784, 0.36078431, 0. ],
[0. , 0. , 0.00392157, 0.01568627, 0.02352941, 0.02745098, 0.00784314, 0. , 0. , 0. , 0. , 0. , 0.92941176, 0.88627451, 0.85098039, 0.8745098 , 0.87058824, 0.85882353, 0.87058824, 0.86666667, 0.84705882, 0.8745098 , 0.89803922, 0.84313725, 0.85490196, 1. , 0.30196078, 0. ],
[0. , 0.01176471, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.24313725, 0.56862745, 0.8 , 0.89411765, 0.81176471, 0.83529412, 0.86666667, 0.85490196, 0.81568627, 0.82745098, 0.85490196, 0.87843137, 0.8745098 , 0.85882353, 0.84313725, 0.87843137, 0.95686275, 0.62352941, 0. ],
[0. , 0. , 0. , 0. , 0.07058824, 0.17254902, 0.32156863, 0.41960784, 0.74117647, 0.89411765, 0.8627451 , 0.87058824, 0.85098039, 0.88627451, 0.78431373, 0.80392157, 0.82745098, 0.90196078, 0.87843137, 0.91764706, 0.69019608, 0.7372549 , 0.98039216, 0.97254902, 0.91372549, 0.93333333, 0.84313725, 0. ],
[0. , 0.22352941, 0.73333333, 0.81568627, 0.87843137, 0.86666667, 0.87843137, 0.81568627, 0.8 , 0.83921569, 0.81568627, 0.81960784, 0.78431373, 0.62352941, 0.96078431, 0.75686275, 0.80784314, 0.8745098 , 1. , 1. , 0.86666667, 0.91764706, 0.86666667, 0.82745098, 0.8627451 , 0.90980392, 0.96470588, 0. ],
[0.01176471, 0.79215686, 0.89411765, 0.87843137, 0.86666667, 0.82745098, 0.82745098, 0.83921569, 0.80392157, 0.80392157, 0.80392157, 0.8627451 , 0.94117647, 0.31372549, 0.58823529, 1. , 0.89803922, 0.86666667, 0.7372549 , 0.60392157, 0.74901961, 0.82352941, 0.8 , 0.81960784, 0.87058824, 0.89411765, 0.88235294, 0. ],
[0.38431373, 0.91372549, 0.77647059, 0.82352941, 0.87058824, 0.89803922, 0.89803922, 0.91764706, 0.97647059, 0.8627451 , 0.76078431, 0.84313725, 0.85098039, 0.94509804, 0.25490196, 0.28627451, 0.41568627, 0.45882353, 0.65882353, 0.85882353, 0.86666667, 0.84313725, 0.85098039, 0.8745098 , 0.8745098 , 0.87843137, 0.89803922, 0.11372549],
[0.29411765, 0.8 , 0.83137255, 0.8 , 0.75686275, 0.80392157, 0.82745098, 0.88235294, 0.84705882, 0.7254902 , 0.77254902, 0.80784314, 0.77647059, 0.83529412, 0.94117647, 0.76470588, 0.89019608, 0.96078431, 0.9372549 , 0.8745098 , 0.85490196, 0.83137255, 0.81960784, 0.87058824, 0.8627451 , 0.86666667, 0.90196078, 0.2627451 ],
[0.18823529, 0.79607843, 0.71764706, 0.76078431, 0.83529412, 0.77254902, 0.7254902 , 0.74509804, 0.76078431, 0.75294118, 0.79215686, 0.83921569, 0.85882353, 0.86666667, 0.8627451 , 0.9254902 , 0.88235294, 0.84705882, 0.78039216, 0.80784314, 0.72941176, 0.70980392, 0.69411765, 0.6745098 , 0.70980392, 0.80392157, 0.80784314, 0.45098039],
[0. , 0.47843137, 0.85882353, 0.75686275, 0.70196078, 0.67058824, 0.71764706, 0.76862745, 0.8 , 0.82352941, 0.83529412, 0.81176471, 0.82745098, 0.82352941, 0.78431373, 0.76862745, 0.76078431, 0.74901961, 0.76470588, 0.74901961, 0.77647059, 0.75294118, 0.69019608, 0.61176471, 0.65490196, 0.69411765, 0.82352941, 0.36078431],
[0. , 0. , 0.29019608, 0.74117647, 0.83137255, 0.74901961, 0.68627451, 0.6745098 , 0.68627451, 0.70980392, 0.7254902 , 0.7372549 , 0.74117647, 0.7372549 , 0.75686275, 0.77647059, 0.8 , 0.81960784, 0.82352941, 0.82352941, 0.82745098, 0.7372549 , 0.7372549 , 0.76078431, 0.75294118, 0.84705882, 0.66666667, 0. ],
[0.00784314, 0. , 0. , 0. , 0.25882353, 0.78431373, 0.87058824, 0.92941176, 0.9372549 , 0.94901961, 0.96470588, 0.95294118, 0.95686275, 0.86666667, 0.8627451 , 0.75686275, 0.74901961, 0.70196078, 0.71372549, 0.71372549, 0.70980392, 0.69019608, 0.65098039, 0.65882353, 0.38823529, 0.22745098, 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.15686275, 0.23921569, 0.17254902, 0.28235294, 0.16078431, 0.1372549 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
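Dividing a NumPy uint8 array by 255 applies the division to every element and produces a float array, so normalization is just per-pixel arithmetic. A plain-Python sketch on a few pixels copied from the first image (normalize is a hypothetical helper):

```python
def normalize(row):
    """Scale 8-bit pixel values (0-255) into floats in [0, 1]."""
    return [pixel / 255 for pixel in row]


# Columns 12-20 of row index 3 (the fourth row) of the first training image,
# copied from the raw uint8 output earlier.
raw = [1, 0, 0, 13, 73, 0, 0, 1, 4]

# Rounded to 8 decimals, these match the same columns of the normalized array.
print([round(v, 8) for v in normalize(raw)])
```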
Building the model
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(),                           # 28x28 image -> 784 values
        tf.keras.layers.Dense(128, activation=tf.nn.relu),   # hidden layer, 128 neurons
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)  # output: one probability per class
    ]
)
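Here we build a neural network with 3 layers. Layer 1: Flatten unrolls each 28×28 image into a one-dimensional array of 784 (28×28) values. Layer 2: a fully connected (Dense) layer with 128 neurons, using the ReLU activation function. Layer 3: a fully connected layer that is also the output layer, with 10 neurons, one per class. Because this is a multi-class task, it uses softmax as its activation: softmax turns the 10 outputs into probabilities that sum to 1, one for each class.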
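To make the three layers concrete, here is a plain-Python sketch (no TensorFlow) of what each one computes on a single image. The helper names are hypothetical, and the real Keras layers work on whole batches of tensors with learned weights; this only mirrors the arithmetic:

```python
import math


def flatten(image):
    """Flatten: 28x28 nested list -> list of 784 values, row by row."""
    return [pixel for row in image for pixel in row]


def relu(xs):
    """ReLU keeps positive values and replaces negative ones with 0."""
    return [max(0.0, x) for x in xs]


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dense_param_count(n_inputs, n_units):
    """A Dense layer has one weight per input per unit, plus one bias per unit."""
    return n_inputs * n_units + n_units


image = [[0] * 28 for _ in range(28)]
print(len(flatten(image)))                      # 784
print(dense_param_count(784, 128))              # 100480
print(dense_param_count(128, 10))               # 1290
print(relu([-2.0, 0.5]))                        # [0.0, 0.5]
print(round(sum(softmax([1.0, 2.0, 3.0])), 6))  # 1.0
```

For the two Dense layers, these counts (100,480 and 1,290 parameters) should match what model.summary() reports once the model is built; Flatten itself has no parameters.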
Next, compile and train the model:
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(training_images, training_labels, epochs = 5)
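epochs = 5 is the number of training rounds: the model makes 5 full passes over the training set. The loss sparse_categorical_crossentropy is used because the labels are plain integers (0-9) rather than one-hot vectors.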
The output:
Epoch 1/5
1875/1875 [==============================] - 2s 692us/step - loss: 0.5020 - accuracy: 0.8235
Epoch 2/5
1875/1875 [==============================] - 1s 680us/step - loss: 0.3727 - accuracy: 0.8664
Epoch 3/5
1875/1875 [==============================] - 1s 735us/step - loss: 0.3389 - accuracy: 0.8765
Epoch 4/5
1875/1875 [==============================] - 1s 672us/step - loss: 0.3146 - accuracy: 0.8853
Epoch 5/5
1875/1875 [==============================] - 1s 686us/step - loss: 0.2960 - accuracy: 0.8910
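Two numbers in this log are worth unpacking. The 1875/1875 counter counts batches, not images: when fit is called without batch_size, Keras uses batches of 32, and 60,000 / 32 = 1,875. The loss column is an average of the per-image sparse categorical cross-entropy, which is -log of the probability the model assigns to the true class. A plain-Python sketch of both calculations (all names are hypothetical; Keras does this on tensors):

```python
import math

BATCH_SIZE = 32  # Keras default when fit()/evaluate() receive no batch_size


def steps_per_epoch(n_samples, batch_size=BATCH_SIZE):
    """Batches needed to see every sample once; the last batch may be smaller."""
    return math.ceil(n_samples / batch_size)


def sparse_categorical_crossentropy(probs, label):
    """Per-sample loss: -log of the probability given to the true class.

    probs is a softmax output (one probability per class); label is an int.
    This sketch does not guard against probs[label] being exactly 0.
    """
    return -math.log(probs[label])


def mean_loss(batch_probs, labels):
    """Average per-sample loss over a batch."""
    losses = [sparse_categorical_crossentropy(p, y) for p, y in zip(batch_probs, labels)]
    return sum(losses) / len(losses)


print(steps_per_epoch(60_000))  # 1875 batches per training epoch
print(steps_per_epoch(10_000))  # 313 batches for the test set

# A confident correct prediction costs little; a confident wrong one costs a lot.
print(sparse_categorical_crossentropy([0.05, 0.90, 0.05], 1))  # about 0.105
print(sparse_categorical_crossentropy([0.05, 0.90, 0.05], 0))  # about 3.0
```

The same batch arithmetic explains the 313/313 counter when the model is evaluated on the 10,000 test images.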
After training, let's measure the model's accuracy on the test set:
model.evaluate(test_images, test_labels)
The test accuracy comes out at about 0.8668:
313/313 [==============================] - 0s 544us/step - loss: 0.3702 - accuracy: 0.8668
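The accuracy metric requested in compile counts an image as correct when its most probable class matches the label. A plain-Python sketch with a hypothetical accuracy helper and made-up toy labels:

```python
def accuracy(predicted_labels, true_labels):
    """Fraction of samples whose predicted class equals the true class."""
    if len(predicted_labels) != len(true_labels):
        raise ValueError("inputs must have the same length")
    correct = sum(p == t for p, t in zip(predicted_labels, true_labels))
    return correct / len(true_labels)


# Toy example: 3 of 4 predictions are right.
print(accuracy([9, 2, 1, 7], [9, 2, 1, 9]))  # 0.75
```

Applied to the 10,000 test images, an accuracy of 0.8668 means about 8,668 of them were classified correctly.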
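To push accuracy higher, you can increase epochs and train for more rounds. Note, though, that more epochs mainly raise training accuracy; test accuracy does not necessarily follow, and the gap between the two can widen.
Next, let's look at the model's prediction for the first image in the test set and check whether it is correct: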
classifications = model.predict(test_images)
print(classifications[0])
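The result is an array of length 10, where the value at index i is the predicted probability that the image belongs to class i. The largest value here is at index 7 (about 0.537, Sneaker), with index 9 (about 0.442, Ankle boot) a close second, so the model predicts class 7.
[3.6468668e-06 5.4612696e-07 4.5575359e-08 2.6194398e-08 2.4610272e-06
2.0921556e-02 4.4843373e-06 5.3708768e-01 9.6340482e-05 4.4188327e-01]
Now let's look at its true label:
test_labels[0]
The result:
9
So this prediction is wrong: the true class is 9 (Ankle boot), which the model ranked only a close second behind Sneaker.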
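Reading the prediction is an argmax: the predicted class is the index of the largest probability. A plain-Python sketch applied to the ten probabilities printed above (CLASS_NAMES is a hypothetical list built from the label table):

```python
# Probabilities printed above for test_images[0]; the index is the class label.
probs = [3.6468668e-06, 5.4612696e-07, 4.5575359e-08, 2.6194398e-08,
         2.4610272e-06, 2.0921556e-02, 4.4843373e-06, 5.3708768e-01,
         9.6340482e-05, 4.4188327e-01]

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# argmax: the index of the largest probability is the predicted class.
predicted = max(range(len(probs)), key=lambda i: probs[i])
# Ranking all classes also reveals the runner-up.
ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

print(predicted, CLASS_NAMES[predicted])  # 7 Sneaker
print(ranked[1], CLASS_NAMES[ranked[1]])  # 9 Ankle boot
```

With NumPy, the same step is np.argmax(classifications[0]) for one image, or np.argmax(classifications, axis=1) for the whole test set.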
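Afterword
At this point, a simple multi-class image classification task is complete. The model does show some overfitting, though: its accuracy on the training set reached 89.1% in the final epoch, but only 86.7% on the test set, so it performs better on images it has already seen than on new ones.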