【主线1】MNIST手写数字识别--朴素贝叶斯

0 阅读2分钟

mnist手写数字识别

1、数据加载

踩坑1: MultinomialNB (and GaussianNB)不能处理负数,而normalization会按分布生成大量负数导致训练模型时会报错,原始数据本身给出的y值范围是[0,255]完全符合输入标准。直接加载即可。后续用于神经网络(对输入尺度、距离敏感的)才需要,而朴素贝叶斯基于概率模型,不需要。 .data 属性是原始未经过 transform 的 Tensor,值范围是 [0, 255] 整数。所以 Normalize 根本没有被用到

对比Normalize 后,值范围大约是 [-0.42, 2.82],均值为 0,标准差为 1。

踩坑2:训练集拆分,原始数据分为训练集和验证集,而测试集需要在训练集中划分而不是把验证集划分。😢

踩坑3:算法要求输入维度特征是相互独立的标量,Scikit-learn 的 API 约定要求输入必须是 (样本数, 特征数) 的二维矩阵,所以需要使用rehsape降维处理,PS:注意view和reshape降维的区别。

2、得分计算

准确率= 实际值与预测值相等的数量/总量。(求平均值) (y′==y).mean()

实际为正实际为负
预测为正TPFP
预测为负 FNTN

一个具体的例子

假设测试集有 100 个样本,其中 20 个是(正类),80 个是非猫(负类)。

模型预测结果:

预测为猫的有 15 个,其中 12 个真的是猫(TP=12),3 个不是猫(FP=3)

预测为非猫的有 85 个,其中 8 个其实是猫(FN=8),77 个确实不是猫(TN=77)

预测值 vs Recall vs Precision vs Accuracy

指标公式回答的问题本例数值
预测值model.predict(某张图)这张图是不是猫?猫 / 非猫
Accuracy(TP+TN) / 总数总体猜对了多少?(12+77)/100 = 0.89
PrecisionTP / (TP+FP)预测为猫的里面,多少真的是猫?12/15 = 0.80
RecallTP / (TP+FN)真正的猫里面,找出了多少?12/20 = 0.60
F1 Score2 × P × R / (P + R)精确率和召回率的调和平均2×0.8×0.6/1.4 = 0.69

Recall 特别低(比如数字 5 的 Recall 只有 0.65),说明模型经常漏掉这个数字。

Precision 特别低,说明模型经常把其他数字误判成这个数字。