【从0开始学AI】用scikit-learn的分类算法实现手写数字的识别

551 阅读8分钟

我保证,即使完全没有学过机器学习的人,也可以通过这篇文章对手写字数识别这个问题,自己上手操作一次。写出你自己的AI。

    

    手写数字的识别是一个很容易入门又特别好理解的机器学习问题。我们知道如果是一个印刷体的数字可能用一个程序比对就可以识别了,但是手写数字会因人而异,每个

数字最终都会有细微差别,因此是一个很好的机器学习问题。

    (什么是好的机器学习问题呢?一个问题无法用一个固定的规则写出程序,但是可以方便的采集到数据和标记)

     我们看一看上面的这个手写数字的图片,如果用 4 X 4的方格来存储一个手写数字的话,那么每个空格上面的笔画就可以作为识别的基础单位,换一个说法,如果某些

方格里正好有特定的笔画,就更有可能是某个数字。

    在scikit-learn里,有一个现成的数据集(手写数字的数据集),它是用一个 8 X 8的数组来存储一个数字的笔画轨迹,显然 8X8 比 4 X4要精确一些。

Digits dataset(手写数字的数据集)

    scikit-learn例子的原文是这么说的:The digits dataset consists of 8x8 pixel images of digits.

    The images attribute of the dataset stores 8x8 arrays of grayscale values for each image

   换句话说,数据集里的数据是一种images图片,由8X8像素组成的图片,也就是一个8X8的数组,数组里的值是一个灰度值,当然我不清楚这里的灰度值到底是什么值,

但我知道,如果说这个灰度值的取值范围是0-255,那么0肯定代表着这个像素是白色的(或者说没有书写痕迹),255就代表这个像素是最终的笔墨。

   可以看一看scikit-learn例子里是怎么把这个数据集加载出来,并且显示几个数字的图像,这样更加直观一点:


# License: BSD 3 clause

# Standard scientific Python imports

import matplotlib.pyplot as plt


# Import datasets, classifiers and performance metrics

from sklearn import datasets, svm, metrics

from sklearn.model_selection import train_test_split


     这个部分代码就是加载一些必要的库,比如plt,datasets 等,没必要深究,重点看下面的数据集加载代码:

digits = datasets.load_digits()

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))

for ax, image, label in zip(axes, digits.images, digits.target):

    ax.set_axis_off()

    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")

    ax.set_title("Training: %i" % label)

 

如果对Python或机器学习不是很熟悉的朋友,千万不要紧张,其实这几句话还是相对比较简单的,我把几个关键的地方给你指明出来,就非常好理解了:

  • 1, datasets.load_digits()

            这句非常简单,load_digits就是加载手写数字数据集,所谓数据集其实就是一个包含手写数字8X8像素图和正确的结果的数据集合。比如,数字0的像素图

         (0)

       你看我故意在这个像素图的右边写了一个0,这个0是怎么来的呢,不是机器学习识别出来的,是人看完这个手写的数字之后标记出来的,你可以想见,

数据集里都是这样的正确的 像素图和数字 的结对。

  • 2, plt.subplots(nrows=1, ncols=4

        plt是一个画图的库,我们这段代码是把数据集的前4个数据显示出来,因此创建一个 1 * 4的显示区域,也就是rows=1,cols=4, 一行四列。

  • 3, ax.imshow(image

        注意,这里的ax会随着循环移到每一个plt的subplot上,也即是依次移到1*4的显示区域上,然后imshow把image图片显示出来,这里的image就是

   8X8 的像素图

  • 4, ax.set_title("Training: %i" % label)

      注意这里的label就是前面我提到的,数据集里 和 像素图1:1对应的正确数字。

把这一小段代码调试执行出来的效果如下图所示:

其实就是显示数据集的前4个数据(图像和正确的数字),不过看完这个效果图,我们对像素图有了更清晰的认知,也就为我们接下来做机器学习提供了很好的基础

Classification(机器学习)

我们没有能力去改变数据集,它的特性就是上面展示的那样了。现在我们要思考的主要问题是,在这个数据集的基础上怎么做机器学习?

本质上来说,我们就是要找到一种算法,假设写作 y = f(x)

这里的x是什么呢?y是什么呢?在机器学习里,x就是那个我们人能识别的原始信息,y就是要识别的结果,f就是机器学习的模型,在手写数字是别中,

x就是上面的像素图,y就是像素图真正对应的数字 (0,1,2,3…),我的任务是要找到合适的f

进一步的思考,我们可以把0-9这10个数字,看做10个类别,如果这么思考的话,f其实就是要把原始的像素图分到这个10个类别下。

换一句AI领域的话来说,也就是机器学习里的 分类算法(Classification)

#注意

    不过这还不是问题的关键,问题的关键在于x,一个8X8的数组不是一个标准的分类算法参数,在机器学习里,算法的输入x,必须是一个

包含n的属性列的数据,因此最重要的问题是如何把 8X8的像素数组进行转换,这里我们可以看一看例子中的原文怎么描述的:

To apply a classifier on this data, we need to flatten the images, turning each 2-D array of grayscale values from shape (8, 8)

 into shape (64,). Subsequently, the entire dataset will be of shape (n_samples, n_features), 

flatten我理解意思是进行一个压扁,也就是把二维的8*8压扁放平成 64个像素,这样的话,每个像素的值,就可以看做这个数据的一个属性,也就是

机器学习里的features,换一种说法,我用64个像素位置上的像素值来表示一个手写数字的一个特点。这样处理的话,最终数据就变成

(n_samples, n_features) 的数组,这里的n_samples就是数据集的数据个数

写出代码如下:

# flatten the images

n_samples = len(digits.images)

data = digits.images.reshape((n_samples, -1))

  

# Create a classifier: a support vector classifier

clf = svm.SVC(gamma=0.001)

# Split data into 50% train and 50% test subsets

X_train, X_test, y_train, y_test = train_test_split(

    data, digits.target, test_size=0.5, shuffle=False

)

# Learn the digits on the train subset

clf.fit(X_train, y_train)


# Predict the value of the digit on the test subset

predicted = clf.predict(X_test)

    特别注意,这段代码包含刚才说的数据处理和分类算法的训练,已经最终的预测,我分别解释一下:

  • data = digits.images.reshape((n_samples, -1))

            就是刚才说的数据压扁过程,把88 压成641

  • clf = svm.SVC(gamma=0.001)

            分类算法,选 svm(别问为什么选svm,先看看svm的效果,不行还可以换的)

  • X_train, X_test, y_train, y_test = train_test_split(

        这里的train_test_split 就是机器学习里的 训练集和测试集拆分,训练集是为了训练取得分类算法模型最终参数的,测试集是测试模型参数的准确率用的(看后面就知道)

  • clf.fit(X_train, y_train)

        训练(clf就是前面的分类算法svm)

  • predicted = clf.predict(X_test)

        测试(或者叫预测),X_test就是刚才的测试集,predicted可以说就是前面提到的y,也就是我们要让机器帮我们做分类得到的数字0-9, clf.predict就是我们求得的f

这就结束了吗

    需要特别注意,机器学习到这一步不算结束,接下来最重要的,是分析predicted这个预测结果,为什么呢?

    我们注意到,X_test其实是对应了一个Y_test的,这个Y_test就是完全正确的结果,那么用机器学习分类得到的predicted和Y_test会一样吗?

    如果100%一样,反而不是好的AI(这就是机器学习里的过犹不及)

    其实任何问题,包括手写识别问题,人都不可能100%识别准备,何况机器呢。

   因此我们需要分析这里的predicted,也就是例子中的如下代码:


    f"Classification report for classifier {clf}:\n"

    f"{metrics.classification_report(y_test, predicted)}\n"

)

    classification_report 其实就是把我们刚才提到的 预测集合和测试集合正确结果之间的差异比对 更加细化了,结果得到当前模型的准确率为97%左右。

因此针对这个问题,其实只是开始,不是结束,还记得之前的分类算法吗?用的是svm,我们可以问自己,97%就是最好结果了吗?可以换一个分类算法吗?

只要你开始问这个问题,意味着你真正开始入门AI了

    现在你就可以选用其他的分类器和调试分类器的参数,继续来训练你的机器学习模型。

#附

机器学习两个重要概念,也就是上面的结果报告里的  precision 和 recall

precision 就是精确度,recall就是召回率,注意这两个都是判断结果准确度的指标,不过算法不一样。

具体可以参看另外一个专题文章:zhuanlan.zhihu.com/p/369936908…

#参考资料

参考:scikit-learn.org/stable/auto…