机器学习-KNN实现数据集DBRHD手写识别

736 阅读4分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

1. 项目简介:

本文将使用sklearnsklearn来训练一个K最近邻(k-Nearest Neighbor, KNN) 分类器,用于识别数据集DBRHD的手写数字。


2. 概念阐述:

KNN(K Near Neighbor):k个最近的邻居,即每个样本都可以用它最接近的k个邻居来代表,KNN是一种分类算法,该算法可以应用在文本识别、分类以及图像识别等领域。

KNN最邻近分类算法的实现原理(近朱者赤近墨者黑):

  • 为了判断未知样本的类别,以所有已知类别的样本作为参照,计算未知样本与所有已知样本的距离。
  • 从中选取与未知样本距离最近的K个已知样本
  • 根据少数服从多数的投票法则(majority-voting),将未知样本与K个最邻近样本中所属类别占比较多的归为一类。
  • 例如,如下图所示,先判断绿色方框属于哪一类。则k=3是绿色方框被分类为蓝色三角,K=5时,绿色方框被分类为红色圆圈。

356d5c3be1354ba8baca9ccaa61d166d.png

DBRHD数据集:每一个图片是由0或1组成的32*32的文本矩阵,如下图所示。

sklearn库:sklearn库共分为6大部分,分别用于完成分类任务、回归任务、聚类任务、降维任务、模型选择以及数据的预处理。提供一批统一化的机器学习方法功能接口。

因此KNN的输入为图片矩阵展开的一个1024维的向量。

KNN是一种懒惰学习法,没有学习过程,只在预测时去查找最近邻的点,数据集的输入就是构建KNN分类器的过程。

注: 文末附有本文数据集文件,分享给大家一起学习


3. 程序代码:

import numpy as np     #导入numpy工具包
from os import listdir   #使用listdir模块,用于访问本地文件
from sklearn import neighbors
import time

start = time.perf_counter()
## 定义img2vector函数,将加载的32*32的图片矩阵展开成一列向量
def img2vector(fileName):
    retMat = np.zeros([1024],int) #定义返回的矩阵,大小为1*1024
    fr = open(fileName)           #打开包含32*32大小的数字文件
    lines = fr.readlines()        #读取文件的所有行      #readlines()从文件中一行一行地读数据,返回一个列表
    for i in range(32):           #遍历文件所有行
        for j in range(32):       #并将0-1数字存放在retMat中
            retMat[i*32+j] = lines[i][j]
    return retMat

##定义加载训练数据的函数readDataSet
def readDataSet(path):
    fileList = listdir(path)    #获取文件夹下的所有文件
    numFiles = len(fileList)    #统计需要读取的文件的数目
    dataSet = np.zeros([numFiles,1024],int)    #用于存放所有的数字文件
    hwLabels = np.zeros([numFiles])#用于存放对应的标签(与神经网络的不同)
    for i in range(numFiles):      #遍历所有的文件
        filePath = fileList[i]     #获取文件名称/路径
        digit = int(filePath.split('_')[0])   #通过文件名获取标签
        hwLabels[i] = digit        #直接存放数字,并非one-hot向量
        dataSet[i] = img2vector(path +'/'+filePath)    #读取文件内容
    return dataSet,hwLabels

#调用以上定义的两个函数加载数据,将训练的图片存放在train_dataSet中,对应的标签(文件名中的有用信息标签)存放在train_hwLabels中
train_dataSet, train_hwLabels = readDataSet('trainingDigits')

##构建KNN分类器
knn = neighbors.KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
knn.fit(train_dataSet, train_hwLabels)

#加载测试集
dataSet,hwLabels = readDataSet('testDigits')

##使用构建好的KNN分类器对测试集进行预测,并计算预测的错误率
res = knn.predict(dataSet)  #对测试集进行预测
error_num = np.sum(res != hwLabels) #统计分类错误的数目
num = len(dataSet)          #测试集的数目
print("Total num:",num," Wrong num:", error_num,"  WrongRate:",error_num / float(num))
end = time.perf_counter()
t = end-start
print("所用时间为:", t)

4. 运行结果:

经过构建好的KNN分类器对测试集的预测,结果如下图所示:

KNN的准确率远高于MLP分类器,这是由于MLP在小数据集上容易过拟合。

本案例中使用的数据集文件分享给大家:

链接:百度网盘 请输入提取码 
提取码:ZiDa

将其中的训练手写图像文件夹命名为trainingDigits

将其中的测试手写图像文件夹命名为testDigits

保存到相应的工作区,运行本文代码即可