Getting Started with Computer Vision in TensorFlow: Multi-class Classification on Fashion-MNIST


Introduction

The previous article walked through TensorFlow's "Hello World": fitting a simple linear model.

Today we tackle one of the simplest image-classification tasks.

The Dataset

We will use the Fashion-MNIST dataset. Its GitHub page is: github.com/zalandorese…

The dataset contains images like these:

(figure: sample images from the Fashion-MNIST dataset)

Each training example has a label in the range 0-9, for a total of 10 classes. The meaning of each label is given in the table below:

Label   Description
0       T-shirt/top
1       Trouser
2       Pullover
3       Dress
4       Coat
5       Sandal
6       Shirt
7       Sneaker
8       Bag
9       Ankle boot
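The table above can be captured in code as a simple lookup list. (The `class_names` list and `label_to_name` helper here are our own naming for illustration, not part of the dataset API.)

```python
# Human-readable names for the 10 Fashion-MNIST labels, in label order.
class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_to_name(label):
    """Map a numeric label (0-9) to its clothing category."""
    return class_names[label]

print(label_to_name(9))  # Ankle boot
```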

A First Look at the Data

Let's start coding, using TensorFlow for this image-classification task.

First, load the dataset:

import tensorflow as tf
mnist = tf.keras.datasets.fashion_mnist.load_data()

Unpack the training images, training labels, test images, and test labels:

(training_images, training_labels), (test_images, test_labels) = mnist

Next, let's look at the first image in the training set:

import matplotlib.pyplot as plt
training_images[0]

We find that it is a multi-dimensional array:

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
          0,   0,  13,  73,   0,   0,   1,   4,   0,   0,   0,   0,   1,
          1,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
          0,  36, 136, 127,  62,  54,   0,   0,   0,   1,   3,   4,   0,
          0,   3],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6,
          0, 102, 204, 176, 134, 144, 123,  23,   0,   0,   0,   0,  12,
         10,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0, 155, 236, 207, 178, 107, 156, 161, 109,  64,  23,  77, 130,
         72,  15],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   0,
         69, 207, 223, 218, 216, 216, 163, 127, 121, 122, 146, 141,  88,
        172,  66],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   0,
        200, 232, 232, 233, 229, 223, 223, 215, 213, 164, 127, 123, 196,
        229,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        183, 225, 216, 223, 228, 235, 227, 224, 222, 224, 221, 223, 245,
        173,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        193, 228, 218, 213, 198, 180, 212, 210, 211, 213, 223, 220, 243,
        202,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   3,   0,  12,
        219, 220, 212, 218, 192, 169, 227, 208, 218, 224, 212, 226, 197,
        209,  52],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6,   0,  99,
        244, 222, 220, 218, 203, 198, 221, 215, 213, 222, 220, 245, 119,
        167,  56],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   4,   0,   0,  55,
        236, 228, 230, 228, 240, 232, 213, 218, 223, 234, 217, 217, 209,
         92,   0],
       [  0,   0,   1,   4,   6,   7,   2,   0,   0,   0,   0,   0, 237,
        226, 217, 223, 222, 219, 222, 221, 216, 223, 229, 215, 218, 255,
         77,   0],
       [  0,   3,   0,   0,   0,   0,   0,   0,   0,  62, 145, 204, 228,
        207, 213, 221, 218, 208, 211, 218, 224, 223, 219, 215, 224, 244,
        159,   0],
       [  0,   0,   0,   0,  18,  44,  82, 107, 189, 228, 220, 222, 217,
        226, 200, 205, 211, 230, 224, 234, 176, 188, 250, 248, 233, 238,
        215,   0],
       [  0,  57, 187, 208, 224, 221, 224, 208, 204, 214, 208, 209, 200,
        159, 245, 193, 206, 223, 255, 255, 221, 234, 221, 211, 220, 232,
        246,   0],
       [  3, 202, 228, 224, 221, 211, 211, 214, 205, 205, 205, 220, 240,
         80, 150, 255, 229, 221, 188, 154, 191, 210, 204, 209, 222, 228,
        225,   0],
       [ 98, 233, 198, 210, 222, 229, 229, 234, 249, 220, 194, 215, 217,
        241,  65,  73, 106, 117, 168, 219, 221, 215, 217, 223, 223, 224,
        229,  29],
       [ 75, 204, 212, 204, 193, 205, 211, 225, 216, 185, 197, 206, 198,
        213, 240, 195, 227, 245, 239, 223, 218, 212, 209, 222, 220, 221,
        230,  67],
       [ 48, 203, 183, 194, 213, 197, 185, 190, 194, 192, 202, 214, 219,
        221, 220, 236, 225, 216, 199, 206, 186, 181, 177, 172, 181, 205,
        206, 115],
       [  0, 122, 219, 193, 179, 171, 183, 196, 204, 210, 213, 207, 211,
        210, 200, 196, 194, 191, 195, 191, 198, 192, 176, 156, 167, 177,
        210,  92],
       [  0,   0,  74, 189, 212, 191, 175, 172, 175, 181, 185, 188, 189,
        188, 193, 198, 204, 209, 210, 210, 211, 188, 188, 194, 192, 216,
        170,   0],
       [  2,   0,   0,   0,  66, 200, 222, 237, 239, 242, 246, 243, 244,
        221, 220, 193, 191, 179, 182, 182, 181, 176, 166, 168,  99,  58,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  40,  61,  44,  72,  41,  35,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)

Next, check its exact shape:

training_images[0].shape

The result is a 28×28 two-dimensional array:

(28, 28)
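Each image is therefore a 28×28 grid of pixel intensities, 784 values in total. A small NumPy sketch (using a zero array as a stand-in for a real image) shows what flattening such a grid does:

```python
import numpy as np

# A stand-in for one Fashion-MNIST image: a 28x28 array of uint8 pixels.
image = np.zeros((28, 28), dtype=np.uint8)

# Flattening turns the 2-D grid into a 1-D vector of 28 * 28 = 784 values,
# which is exactly what the model's Flatten layer will do per image.
flat = image.reshape(-1)
print(flat.shape)  # (784,)
```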

Now let's display it as an actual image:

print(training_labels[0])
plt.imshow(training_images[0])

The output is shown below; label 9 corresponds to Ankle boot.

9
<matplotlib.image.AxesImage at 0x19eb8ff2310>

(figure: the first training image, an ankle boot)

Data Preprocessing

Before classifying the images, we first normalize the pixel values:

training_images = training_images / 255.0
test_images = test_images / 255.0
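Dividing by 255 maps each uint8 pixel from the range [0, 255] into [0, 1]. A quick NumPy check, independent of the dataset, with a few sample pixel values:

```python
import numpy as np

# Sample pixel values as they might appear in a raw image.
pixels = np.array([0, 13, 73, 255], dtype=np.uint8)

# Normalize into the [0, 1] range used for training.
normalized = pixels / 255.0

print(normalized)  # [0.         0.05098039 0.28627451 1.        ]
```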

Then check whether the content of the first image has changed:

training_images[0]

All the numbers in the first image are now between 0 and 1, which is exactly what normalization should give us:

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       ...
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.00392157, 0.        , 0.        ,
        0.05098039, 0.28627451, 0.        , 0.        , 0.00392157,
        0.01568627, 0.        , 0.        , 0.        , 0.        ,
        0.00392157, 0.00392157, 0.        ],
       ...])
(output truncated; every value now lies between 0 and 1)

Building the Model

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation = tf.nn.relu),
        tf.keras.layers.Dense(10, activation = tf.nn.softmax)
    ]
)

Here we build a neural network with three layers. The first layer is a Flatten layer that unrolls each 28×28 image into a one-dimensional array of 784 values. The second layer is a fully connected (Dense) layer with 128 neurons, using the ReLU activation function. The third layer is the fully connected output layer with 10 neurons, one per class; since this is a multi-class task, its activation is softmax, which turns the 10 outputs into a probability distribution over the classes.
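As a quick sanity check on what softmax does, here is a NumPy sketch with made-up logits (not the model's actual outputs):

```python
import numpy as np

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum()

# Hypothetical raw outputs of the 10-neuron layer for one image.
logits = np.array([1.0, 0.5, 0.2, 0.1, 0.3, 2.0, 0.4, 3.5, 0.2, 4.0])
probs = softmax(logits)

print(probs.argmax())  # 9 -- the class with the largest raw score wins
```

For reference, with 784 inputs the first Dense layer has 784 × 128 + 128 = 100,480 trainable parameters, and the output layer has 128 × 10 + 10 = 1,290.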

Next, compile the model and train it:

model.compile(
    optimizer = tf.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    metrics = ['accuracy']
)
model.fit(training_images, training_labels, epochs = 5)

epochs = 5 sets the number of training passes over the data.

The training output looks like this:

Epoch 1/5
1875/1875 [==============================] - 2s 692us/step - loss: 0.5020 - accuracy: 0.8235
Epoch 2/5
1875/1875 [==============================] - 1s 680us/step - loss: 0.3727 - accuracy: 0.8664
Epoch 3/5
1875/1875 [==============================] - 1s 735us/step - loss: 0.3389 - accuracy: 0.8765
Epoch 4/5
1875/1875 [==============================] - 1s 672us/step - loss: 0.3146 - accuracy: 0.8853
Epoch 5/5
1875/1875 [==============================] - 1s 686us/step - loss: 0.2960 - accuracy: 0.8910
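The 1875 on each line is the number of batches per epoch: with 60,000 training images and Keras's default batch size of 32, each epoch takes 60000 / 32 = 1875 steps. The loss column is the sparse categorical cross-entropy, which for a single example is just the negative log of the probability the model assigned to the true class. A minimal NumPy sketch with made-up probabilities:

```python
import numpy as np

def sparse_categorical_crossentropy(probs, label):
    """Loss for one example: -log(probability assigned to the true class)."""
    return -np.log(probs[label])

# Hypothetical softmax output for one image whose true label is 9.
probs = np.array([0.01, 0.01, 0.01, 0.01, 0.01,
                  0.02, 0.01, 0.05, 0.02, 0.85])
loss = sparse_categorical_crossentropy(probs, label=9)
print(round(loss, 4))  # 0.1625
```

The closer the true-class probability gets to 1, the closer this loss gets to 0.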

After training, let's check the model's accuracy on the test set:

model.evaluate(test_images, test_labels)

The accuracy comes out at roughly 0.8668:

313/313 [==============================] - 0s 544us/step - loss: 0.3702 - accuracy: 0.8668

If you want higher accuracy, you can increase epochs appropriately and train for more rounds.

Next, let's run the model on the test set and look at its prediction for the first test image:

classifications = model.predict(test_images)

print(classifications[0])

The result is an array of length 10. Each entry is the predicted probability that the image belongs to the corresponding class 0-9, and the predicted class is the index of the largest value.

[3.6468668e-06 5.4612696e-07 4.5575359e-08 2.6194398e-08 2.4610272e-06
 2.0921556e-02 4.4843373e-06 5.3708768e-01 9.6340482e-05 4.4188327e-01]
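To turn such a probability vector into a class label, take the index of its largest entry, e.g. with NumPy's argmax. (The vector below is made up for illustration, not the model's actual output.)

```python
import numpy as np

# A hypothetical probability vector for one test image.
probs = np.array([0.001, 0.001, 0.001, 0.001, 0.002,
                  0.020, 0.004, 0.090, 0.030, 0.850])

predicted_label = int(np.argmax(probs))  # index of the largest probability
print(predicted_label)  # 9
```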

Let's compare this with the true label:

test_labels[0]

The result is:

9

This matches the model's prediction.

Closing Notes

With this, a simple multi-class image-classification task is complete. Note, however, that the model overfits slightly: accuracy on the training set is about 89%, while on the test set it is only about 86%.