手写数字识别模型

742 阅读4分钟

一、案例背景

图像识别是机器学习领域一个非常重要的应用场景,像现在非常火的人脸识别就是基于机器学习的图像识别相关算法的。这里先介绍一个较为简单的图像识别案例——手写数字识别模型,其原理与人脸识别有共通之处。

二、原理

手写数字识别,或者说图像识别的本质就是把如下图所示的一张图片转换成计算机能够处理的数字形式。

image.png

1、图像二值化

如下图所示是将图片格式的数字4转换成由0和1组成的“新的数字4”。这是一个32×32的矩阵,数字1代表有颜色的地方,数字0代表无颜色的地方,这样就完成了手写数字识别的第一步也是最关键的一步:将图片转换为计算机能识别的内容——数字0和1。这个步骤又称为图像二值化。 image.png

2、二维数组转换为一维数组

经过图像二值化处理获得的32×32的0-1矩阵相当于一个二维数组,为了方便进行机器学习建模,还需要对这个二维数组进行简单的处理:在第1行数字之后依次拼接第2~32行的数字,得到一个1×1024的一维数组,如下所示。

image.png

3、距离计算

手写数字图片处理后形成的1×1024的二维数组可以看成一个行向量,两张图片对应的行向量间的欧氏距离可以反映两张图片的相似度。因此,我们可以利用K近邻算法模型计算新样本与原始训练集中各个样本的欧氏距离,取新样本的K个近邻点,并以大多数近邻点所在的分类作为新样本的分类。

例如,有一个样本为手写数字4,将它转换为如下所示的1×1024的行向量。

将另一个手写数字x转换为如下所示的1×1024的行向量,假设其只有中间一个数字不同。

那么手写数字x与样本手写数字4的距离的计算过程如下。

image.png

根据类似的原理,对于一个新样本,我们可以计算它与每个不同样本数字之间的距离,再根据与其距离最近的K个近邻点判别其属于哪个分类,即哪个数字。

三、代码实现

现在以如下图所示的数据集为例讲解手写数字识别的代码实现。该数据集为1934个处理好的手写数字0~9的1×1024矩阵,其中每一行为一个手写数字,第1列“对应数字”为该手写数字,其余每一列为该手写数字对应的1×1024矩阵中的一个数字

image.png

1、读取数据

下载 手写字体识别.xlsx

首先通过pandas库读取数据,代码如下。

import pandas as pd
df = pd.read_excel("手写字体识别.xlsx")
df.head()

image.png

2、提取特征变量和目标变量

x = df.drop(columns="对应数字")
y = df["对应数字"]

第1行代码用drop()函数删除“对应数字”列,将剩下的数据作为特征变量赋给x;

第2行代码提取“对应数字”列作为目标变量赋给y。

所有样本的1×1024矩阵都由0和1构成,故无须做标准化处理。如果在其他场景中出现数量级相差较大的特征变量,则需要对数据进行标准化处理,代码如下。

from sklearn.preprocessing import StandardScaler 
x = StandardScaler().fit_transform(x)

3、划分训练集和测试集

from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=123)

4、模型搭建

from sklearn.neighbors import KNeighborsClassifier as KNN
knn = KNN(n_neighbors=5)
knn.fit(x_train,y_train)

5、模型预测与评估

y_pred = knn.predict(x_test)

通过打印输出y_pred[0:100]查看前100个预测结果,如下图所示。

image.png

通过和之前章节类似的代码,我们可以将预测值和实际值进行对比。

a = pd.DataFrame()
a["预测值"] = list(y_pred)
a["实际值"] = list(y_test)
a.head()

此时a的前5组数据见下表,可以看到,前5组数据的预测准确度为100%

image.png

通过如下代码可以查看对整个测试集的预测准确度。

knn.score(x_test,y_test)

输出如下:

0.9767441860465116

前面搭建模型时设置n_neighbors参数为5,如果想换成其他数值进行参数调优,可以模仿交叉验证和网格搜索来完成,代码如下。

from sklearn.model_selection import GridSearchCV
# 参数候选值
params = {'n_neighbors':[1,2,3,4,5,6,7,8]}  
knn = KNN()
# 5折交叉验证
grid_serach = GridSearchCV(knn,params,cv = 5)
# 以准确度为基础进行网格搜索,寻找参数最优值
grid_serach.fit(x_train,y_train)
# 获取参数最优值
grid_serach.best_params_['n_neighbors']

输出

3

四、总结

总体来说,K近邻算法是一种非常经典的机器学习算法,其原理清晰简单,容易理解,不过也有一些缺点,例如,样本量较大时计算量大,拟合速度较慢。本章学习的手写数字识别模型其实也是图像识别模型的一个简单应用,是为之后学习更精彩的图像识别模型打好基础。