菜菜的机器学习sklearn【全85集】Python进阶------夏の哉---97it.------top/---14209/
Sklearn 中的 “Hello World”:手写数字识别初体验
在机器学习领域,“手写数字识别” 如同编程世界的 “Hello World”—— 它是一个入门级任务,却完整覆盖了机器学习的核心流程:数据加载、探索、预处理、模型训练、评估与预测。Scikit-learn(Sklearn)作为 Python 最流行的机器学习库,内置了经典的手写数字数据集(MNIST 的简化版),并提供了简洁的 API,让初学者能在 30 行代码内完成从 “加载数据” 到 “识别数字” 的全流程。本文将带你从零开始,用 Sklearn 实现手写数字识别,揭开机器学习的神秘面纱。
一、数据集初探:认识手写数字数据
Sklearn 的datasets模块内置了load_digits数据集,包含 8×8 像素的手写数字图片及对应的标签(0-9),共 1797 个样本。这个数据集虽不及 MNIST 的 28×28 像素精细,但足以用于入门实践。
1. 加载并可视化数据
首先用 Sklearn 加载数据集,并通过可视化了解数据结构:
# 导入库
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
# 加载数据集
digits = load_digits()
X = digits.data # 特征数据:每个样本是64维向量(8×8像素展平)
y = digits.target # 标签:0-9的数字
# 可视化前4个样本
fig, axes = plt.subplots(1, 4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target[:4]):
ax.imshow(image, cmap=plt.cm.gray_r) # 灰度显示
ax.set_title(f"数字: {label}")
ax.axis("off") # 隐藏坐标轴
plt.show()
运行代码后,会显示 4 张手写数字图片(如 0、1、2、3),每张图片的像素为 8×8,取值范围 0-16(0 为白色,16 为黑色)。特征数据X将 8×8 的矩阵展平为 64 维向量(8×8=64),方便模型处理。
2. 数据基本信息
通过简单代码了解数据集规模:
print(f"样本数量: {X.shape[0]}") # 输出1797
print(f"特征维度: {X.shape[1]}") # 输出64(8×8)
print(f"标签类别: {len(set(y))}") # 输出10(0-9)
这些信息告诉我们:需要用 64 个像素特征来预测 10 个可能的数字类别,属于典型的多分类任务。
二、数据预处理:为模型准备 “干净” 的输入
机器学习模型对输入数据敏感,预处理的目标是让数据格式符合模型要求,并减少噪声干扰。对于手写数字数据集,核心预处理步骤是划分训练集与测试集。
1. 划分数据集
我们需要用一部分数据训练模型(学习像素与数字的对应关系),另一部分数据测试模型性能(验证学习效果)。Sklearn 的train_test_split工具可快速实现:
from sklearn.model_selection import train_test_split
# 按7:3划分训练集和测试集,random_state确保结果可复现
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
print(f"训练集样本数: {X_train.shape[0]}") # 输出1257
print(f"测试集样本数: {X_test.shape[0]}") # 输出540
划分的核心原则是:
- 训练集足够大(通常 60%-80%),让模型充分学习规律;
- 测试集具有代表性,能客观反映模型在新数据上的表现。
2. 特征标准化(可选)
手写数字的像素值范围是 0-16,数值差异较小,部分模型(如 SVM、KNN)对特征尺度敏感,需标准化为均值 0、方差 1 的分布:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train) # 用训练集拟合并转换
X_test_scaled = scaler.transform(X_test) # 用训练集的规则转换测试集
注意:测试集的标准化必须使用训练集的均值和方差,避免 “数据泄露”(用测试集信息影响模型)。
三、模型训练:用 SVM 实现数字识别
Sklearn 提供了多种分类算法,对于手写数字识别,支持向量机(SVM) 是经典选择 —— 它在高维空间(如 64 维像素)中表现优异,能有效划分不同数字的特征边界。
1. 初始化并训练模型
from sklearn.svm import SVC
# 初始化SVM模型,使用默认参数(适合入门)
model = SVC()
# 训练模型:用训练集的特征和标签学习规律
model.fit(X_train_scaled, y_train)
SVM 的核心思想是 “找到最佳分隔超平面”:在 64 维空间中,找到一个平面将不同数字的样本分开,且使平面到最近样本的距离(margin)最大。Sklearn 的SVC类已封装了这一复杂过程,我们只需调用fit方法即可。
2. 模型预测
训练完成后,用测试集评估模型的预测能力:
# 用模型预测测试集的数字
y_pred = model.predict(X_test_scaled)
# 查看前10个预测结果与真实标签
print("预测结果:", y_pred[:10])
print("真实标签:", y_test[:10])
输出可能类似:
预测结果: [6 9 3 7 2 1 5 2 5 0]
真实标签: [6 9 3 7 2 1 5 2 5 0]
前 10 个预测全部正确,初步说明模型有效,但需更全面的评估。
四、模型评估:量化识别性能
仅通过少数样本无法判断模型优劣,需要用量化指标和可视化工具全面评估。
1. 核心评估指标
Sklearn 的metrics模块提供了分类任务的常用指标:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# 准确率:正确预测的样本占比
accuracy = accuracy_score(y_test, y_pred)
print(f"准确率: {accuracy:.2f}") # 通常在0.97-0.99之间
# 详细分类报告:包含精确率、召回率、F1分数
print(classification_report(y_test, y_pred))
- 准确率(Accuracy) :整体正确率,适合样本均衡的场景(如每个数字的样本数相近);
- 精确率(Precision) :预测为某数字的样本中,实际为该数字的比例(如预测为 “8” 的样本中,98% 确实是 “8”);
- 召回率(Recall) :实际为某数字的样本中,被正确预测的比例(如所有 “5” 的样本中,97% 被正确识别)。
对于手写数字识别,优秀模型的准确率通常能达到 97% 以上。
2. 混淆矩阵:可视化错误模式
混淆矩阵展示 “真实标签” 与 “预测标签” 的对应关系,能直观发现模型容易混淆的数字(如 “3” 和 “8”、“5” 和 “9”):
import seaborn as sns
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 可视化混淆矩阵
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
xticklabels=digits.target_names,
yticklabels=digits.target_names)
plt.xlabel("预测标签")
plt.ylabel("真实标签")
plt.title("混淆矩阵")
plt.show()
矩阵对角线的数值表示正确预测的样本数,非对角线数值表示错误(如第 3 行第 8 列的数值表示 “真实为 3 却被预测为 8” 的样本数)。通过混淆矩阵,我们能针对性优化模型(如增加易混淆数字的训练样本)。
五、进阶尝试:用不同模型对比效果
除了 SVM,Sklearn 还提供了多种算法,可快速替换模型进行对比:
1. 决策树
from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(max_depth=10)
tree_model.fit(X_train_scaled, y_train)
print("决策树准确率:", accuracy_score(y_test, tree_model.predict(X_test_scaled))) # 约0.85-0.90
2. K 近邻(KNN)
from sklearn.neighbors import KNeighborsClassifier
knn_model = KNeighborsClassifier(n_neighbors=5)
knn_model.fit(X_train_scaled, y_train)
print("KNN准确率:", accuracy_score(y_test, knn_model.predict(X_test_scaled))) # 约0.96-0.98
对比发现,SVM 和 KNN 的准确率高于决策树,这说明不同算法对同一任务的适应性存在差异。在实际应用中,我们需要根据数据特点选择合适的模型。
六、总结与扩展:从识别数字到理解机器学习
手写数字识别的流程虽简单,却包含了机器学习的核心逻辑:
- 数据驱动:模型的能力来自对数据的学习,而非人工编写的规则;
- 泛化能力:训练的目标是让模型在未见过的测试集上表现良好,而非死记硬背训练数据;
- 迭代优化:通过评估指标发现问题(如某数字识别率低),调整模型参数或数据预处理步骤,逐步提升性能。
扩展方向:
- 参数调优:用GridSearchCV优化 SVM 的C和gamma参数,进一步提升准确率;
- 特征工程:尝试提取更有效的特征(如像素梯度、轮廓特征),而非直接使用原始像素;
- 挑战更大数据集:使用完整 MNIST 数据集(28×28 像素,60000 个样本),体验深度学习模型(如 CNN)的优势。
Sklearn 的 “手写数字识别” 就像一把钥匙,它打开了机器学习的大门。当你看到模型能从模糊的像素点中准确识别出数字时,或许就能理解为什么机器学习能在图像识别、语音处理等领域掀起革命 —— 它让计算机具备了 “从数据中学习” 的能力,而这正是智能的起点。