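To get a solid grasp of machine learning, I tried writing the classic example myself: handwritten digit recognition.
sklearn ships with a dedicated handwritten digits dataset, which makes it very convenient to use.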
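Before using the dataset, it can help to see what `load_digits` actually returns. The short sketch below (the variable name `digits` is my own choice) prints its main fields: every sample is an 8x8 grayscale image, stored both flattened in `data` (64 values per row) and as a 2-D array in `images`, with integer intensities from 0 to 16.

```python
from sklearn.datasets import load_digits

digits = load_digits()

# 1797 samples; each row of `data` is an 8x8 image flattened to 64 numbers
print(digits.data.shape)    # (1797, 64)
# the same pixels kept in their 2-D layout
print(digits.images.shape)  # (1797, 8, 8)
# the class labels are the digits 0-9
print(digits.target_names)
# intensities are integers from 0 (background) to 16 (full ink), stored as floats
print(digits.data.min(), digits.data.max())
```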
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
d_data = load_digits()
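# # Inspect d_data
# print(d_data)
# # Inspect all of its keys
# print(d_data.keys())
# # Inspect the value of its 'data' key
# print(d_data['data'])
# # Inspect the first sample in 'data'
# print(d_data['data'][0])
# # Inspect the number of dimensions of that first sample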
# print(d_data['data'][0].ndim)
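from sklearn.model_selection import train_test_split  # import the dataset-splitting utility
# Split into training and test sets (75% / 25% by default); random_state=0 makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(d_data['data'], d_data['target'],
                                                    random_state=0)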
# # Check the shapes of the training and test sets
# print('X_train shape:{}'.format(X_train.shape))
# print('X_test shape:{}'.format(X_test.shape))
# print('y_train shape:{}'.format(y_train.shape))
# print('y_test shape:{}'.format(y_test.shape))
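# k-nearest-neighbours classifier that votes among the 5 closest training samples
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
# Predict on the test set
y_ = knn.predict(X_test)
# Check the prediction accuracy by hand
print(y_test == y_)  # element-wise comparison: a boolean array
print((y_test == y_).mean())  # the mean of a boolean array is the fraction of True values, i.e. the accuracy
# KNN's built-in accuracy method (mean accuracy on the given test data)
print(knn.score(X_test, y_test))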
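# Visualization
from PIL import Image
import numpy as np

def img_show(img):
    # Convert the array to 8-bit unsigned integers, wrap it in a PIL image and open it
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

img = X_train[0]    # first training sample: a flat vector of 64 pixel values
label = y_train[0]  # its true digit
print(label)
print(img.ndim)     # 1
print(img.shape)    # (64,)
img = img.reshape(8, 8)  # restore the 8x8 image layout
img_show(img)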
Output:
The images bundled with sklearn have far too few pixels (8x8) to see clearly; I'll deal with that later.
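One way to make a digit visible, sketched under my own assumptions: enlarge each pixel into a `scale` x `scale` block with NumPy (nearest-neighbour upscaling) and stretch the 0 to 16 intensities of `load_digits` to the 0 to 255 range of an 8-bit image. Without that stretch, `np.uint8(img)` keeps every value at 16 or below, so the digit also comes out almost black. The helper name `upscale_digit` and the default factor 32 are hypothetical choices, not part of sklearn or PIL.

```python
import numpy as np
from sklearn.datasets import load_digits

def upscale_digit(img, scale=32):
    """Hypothetical helper: turn one 64-pixel digit into a larger uint8 grayscale array."""
    img = np.asarray(img, dtype=float).reshape(8, 8)
    # load_digits intensities run 0..16; stretch them to the 0..255 range of an 8-bit image
    img = np.clip(img * (255.0 / 16.0), 0, 255)
    # nearest-neighbour upscaling: repeat each pixel `scale` times along both axes
    big = np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)
    return big.astype(np.uint8)

digits = load_digits()
big = upscale_digit(digits.data[0])
print(big.shape)  # (256, 256)
```

The returned array can go straight into the article's `img_show`, e.g. `img_show(upscale_digit(X_train[0]))`. Ink shows up white on a black background; `255 - upscale_digit(...)` gives black strokes on white instead. PIL's `Image.resize` with nearest-neighbour resampling would achieve the same enlargement.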
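`KNeighborsClassifier(n_neighbors=5)` hides the actual algorithm behind `fit` and `predict`. The sketch below re-implements the idea in plain NumPy to show what happens, under my own simplifying assumptions: Euclidean distance (sklearn's default metric, Minkowski with p=2, is the same thing) and an unweighted majority vote among the 5 nearest training samples. `knn_predict` is a hypothetical name, not sklearn code, and ties may be broken differently from sklearn. With `random_state=0` and the default `test_size=0.25`, the split leaves 1347 training and 450 test samples.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

def knn_predict(X_train, y_train, X_test, k=5):
    """Hypothetical from-scratch k-nearest-neighbours classifier (not sklearn's code)."""
    # Squared Euclidean distance between every test row and every training row,
    # using ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 to avoid building a huge 3-D array
    d2 = ((X_test ** 2).sum(axis=1)[:, None]
          - 2 * X_test @ X_train.T
          + (X_train ** 2).sum(axis=1)[None, :])
    # Indices of the k closest training samples for each test sample
    nearest = np.argsort(d2, axis=1)[:, :k]
    # Labels of those neighbours, shape (n_test, k)
    votes = y_train[nearest]
    # Majority vote per row; on a tie, argmax returns the smallest label
    return np.array([np.bincount(row).argmax() for row in votes])

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=0)

y_pred = knn_predict(X_train, y_train, X_test, k=5)
accuracy = (y_pred == y_test).mean()  # fraction of correct predictions
print(X_train.shape, X_test.shape)
print(accuracy)
```

Skipping the square root is deliberate: it does not change which neighbours are closest, so the ranking and therefore the vote stay the same.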