Classifying Handwritten Digits
This is the companion code for Chapter 3 of the book. First, let's import a few common modules, make sure Matplotlib plots figures inline, and prepare a function to save the figures. We also check that Python 3.5 or later is installed (Python 2.x may work, but it is deprecated, so we strongly recommend you use Python 3), as well as Scikit-Learn ≥ 0.20.
I. Imports
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)
# Is this notebook running on Colab or Kaggle?
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules
# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"
# Common imports
import numpy as np
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)
def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)
II. Loading the Data
- The small datasets bundled with scikit-learn live under the install directory (e.g. D:\Anaconda\Lib\site-packages\sklearn\datasets\data); datasets fetched with the fetch_* functions are cached in the scikit-learn data home instead.
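As a quick check (a sketch; get_data_home is the standard helper, and the location can be overridden), you can ask scikit-learn where it caches fetched datasets:
from sklearn.datasets import get_data_home
get_data_home()  # defaults to ~/scikit_learn_data; override via the SCIKIT_LEARN_DATA env var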
# Import the data loaders
from scipy.io import loadmat
import pandas
from sklearn.datasets import fetch_openml
# Download the data. fetch_openml can be slow or unreliable (e.g. on Ubuntu),
# so you can download the file manually and place it in the current directory instead
# mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist = loadmat('./mnist-original.mat')
# mnist = pandas.read_csv('./mnist_784.csv')
# Show the dataset's keys
mnist.keys()
dict_keys(['__header__', '__version__', '__globals__', 'mldata_descr_ordering', 'data', 'label'])
mnist['__version__']
'1.0'
# Inspect the data and labels
X, y = mnist["data"], mnist["label"]
X = np.transpose(X)
y = np.squeeze(np.transpose(y))
X.shape
(70000, 784)
y.shape
(70000,)
784 is the number of pixels in a 28×28 grayscale image (not a binary image: each pixel is an intensity).
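A quick sanity check of that claim: each row of X reshapes to a 28×28 image, and the pixel values span the usual grayscale range (an assumption for this .mat export; MNIST intensities are typically 0 to 255):
print(X[0].reshape(28, 28).shape)  # (28, 28)
print(X[0].min(), X[0].max())      # grayscale intensities, typically 0.0 and 255.0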
III. Data Preprocessing
# Shuffle the whole dataset.
# Reseeding with the same value before each shuffle applies the same permutation
# to X and y, keeping features and labels aligned
seed = 64
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)  # must reseed with the same seed, or X and y end up in different orders
np.random.shuffle(y)
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off")
save_fig("some_digit_plot")
plt.show()
Saving figure some_digit_plot
y[0]
6.0
# Helper function to plot a single digit image
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap=mpl.cm.binary, interpolation="nearest")
    plt.axis("off")
# Extended version: plot a grid of digit images
def plot_digits(instances, images_per_row=10, **options):
    size = 28
    images_per_row = min(len(instances), images_per_row)
    # This is equivalent to n_rows = ceil(len(instances) / images_per_row):
    n_rows = (len(instances) - 1) // images_per_row + 1
    # Append empty images to fill the end of the grid, if needed:
    n_empty = n_rows * images_per_row - len(instances)
    padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)
    # Reshape the array so it's organized as a grid containing 28×28 images:
    image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))
    # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),
    # and axes 1 and 3 (horizontal axes). We first need to move the axes that we
    # want to combine next to each other, using transpose(), and only then we
    # can reshape:
    big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,
                                                         images_per_row * size)
    # Now that we have a big image, we just need to show it:
    plt.imshow(big_image, cmap=mpl.cm.binary, **options)
    plt.axis("off")
plt.figure(figsize=(9,9))
example_images = X[:100]
plot_digits(example_images, images_per_row=10)
save_fig("more_digits_plot")
plt.show()
Saving figure more_digits_plot
# Split off the first 60,000 images for training, the rest for testing (the data is already shuffled)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
IV. Training a Binary Classifier
# Build a classifier that detects whether a digit is a 5
# Create binary labels for the training and test sets: True for 5s, False otherwise
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
Note: in future versions of Scikit-Learn, some hyperparameters will have different default values, such as max_iter and tol. To be future-proof, we set these hyperparameters explicitly to their future default values. For simplicity, this is not shown in the book.
# Import the stochastic gradient descent (SGD) classifier
from sklearn.linear_model import SGDClassifier
# Instantiate it with the max number of epochs, the stopping tolerance, and a random seed
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
# Train
sgd_clf.fit(X_train, y_train_5)
# Predict whether the first digit image (a 6) is a 5; False is the correct answer
sgd_clf.predict([some_digit])
array([False])
# Check another image (also not a 5)
plot_digit(X[10])
sgd_clf.predict([X[10]])
array([False])
# Measure accuracy with 3-fold cross-validation
from sklearn.model_selection import cross_val_score
# Pass in the model, the training data, the number of folds, and the scoring metric
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.96985, 0.9682 , 0.9658 ])
V. Performance Measures
1. Measuring Accuracy with Cross-Validation
# Import the stratified K-fold splitter
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
# 3-fold stratified cross-validation
skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
# For each fold, get the train and test indices
for train_index, test_index in skfolds.split(X_train, y_train_5):
    # Clone the original model
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]
    # Train
    clone_clf.fit(X_train_folds, y_train_folds)
    # Test
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))
0.9569
0.96375
0.9612
from sklearn.base import BaseEstimator
# A custom estimator that always predicts "not 5"
class Never5Classifier(BaseEstimator):
    def fit(self, X, y=None):
        pass
    def predict(self, X):
        return np.zeros((len(X), 1), dtype=bool)
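This baseline is defined above but never evaluated; as in the book, running it shows why accuracy is misleading on skewed data. Only about 10% of the images are 5s, so always predicting "not 5" already scores roughly 90% accuracy:
never_5_clf = Never5Classifier()
# Roughly 0.90 accuracy in each fold, even though the classifier is useless
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")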
Warning: this output (and many other outputs in this notebook and others) may differ slightly from the book. Don't worry, that's okay! There are several reasons:
- First, Scikit-Learn and other libraries keep evolving, and algorithms get tweaked a bit, which may change the exact results you get. If you use the latest Scikit-Learn version (and in general, you really should), you probably won't be using the exact same version I used when writing the book or this notebook, hence the differences. I try to keep this notebook up to date, but I cannot change the numbers printed in your copy of the book.
- Second, many training algorithms are stochastic, meaning they rely on randomness. In principle you can get consistent outputs by setting the seed of the pseudo-random number generator (which is why you see random_state=42 or np.random.seed(42) so often). However, sometimes this does not suffice due to the other factors listed here.
- Third, if a training algorithm runs across multiple threads (as do some algorithms implemented in C) or across multiple processes (e.g., when using the n_jobs parameter), the precise order in which operations run is not always guaranteed, so the exact results may vary slightly.
- Lastly, other things may prevent perfect reproducibility, such as Python dicts and sets, whose order is not guaranteed to be stable across sessions, or the order of files in a directory.
2. Confusion Matrix
# Import the cross-validation prediction function
from sklearn.model_selection import cross_val_predict
# 3 folds, using the SGD classifier
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
# Import the confusion matrix
from sklearn.metrics import confusion_matrix
# Compute the confusion matrix
confusion_matrix(y_train_5, y_train_pred)
array([[53628, 894],
[ 1029, 4449]])
As you can see, the model classified 894 non-5 images as 5s (false positives), and 1,029 images of actual 5s as other digits (false negatives).
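A small sketch to make the four cells explicit: for a 2×2 confusion matrix, ravel() returns them in the order TN, FP, FN, TP.
tn, fp, fn, tp = confusion_matrix(y_train_5, y_train_pred).ravel()
print("TN:", tn, "FP:", fp, "FN:", fn, "TP:", tp)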
# The confusion matrix of a perfect classifier
y_train_perfect_predictions = y_train_5 # pretend we reached perfection
confusion_matrix(y_train_5, y_train_perfect_predictions)
array([[54522, 0],
[ 0, 5478]])
3. Precision and Recall
# Import precision and recall from sklearn.metrics
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
0.832678270634475
# Compute precision from the confusion matrix: TP / (TP + FP),
# i.e. relative to all samples predicted positive (TP + FP)
cm = confusion_matrix(y_train_5, y_train_pred)
cm[1, 1] / (cm[0, 1] + cm[1, 1])
0.832678270634475
recall_score(y_train_5, y_train_pred)
0.812157721796276
# Compute recall from the confusion matrix: TP / (TP + FN),
# i.e. relative to all samples whose true label is positive (TP + FN)
cm[1, 1] / (cm[1, 0] + cm[1, 1])
0.812157721796276
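Precision and recall are often combined into the F1 score, their harmonic mean: F1 = 2 × P × R / (P + R), which for the values above works out to about 0.822.
from sklearn.metrics import f1_score
# Harmonic mean of precision and recall; low if either one is low
f1_score(y_train_5, y_train_pred)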
4. The Precision/Recall Trade-off
# Get the decision function score for this instance
y_scores = sgd_clf.decision_function([some_digit])
y_scores
array([-10289.64064706])
# Look at the digit image used for this prediction
plot_digit(some_digit)
# With the threshold at 0 (this is effectively what predict uses)
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
# Lowering the threshold to -11000 flips this negative prediction to positive
threshold = -11000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
# Get decision scores for all training instances via cross-validation
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
method="decision_function")
# Import the precision-recall curve from sklearn.metrics
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.legend(loc="center right", fontsize=16)  # Not shown in the book
    plt.xlabel("Threshold", fontsize=16)         # Not shown
    plt.grid(True)                               # Not shown
    plt.axis([-50000, 50000, 0, 1])              # Not shown
# Find the recall and threshold at the point where precision first reaches 0.90
recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
# np.argmax returns the first index where precisions >= 0.90 is True
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
plt.figure(figsize=(8, 4)) # Not shown
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:") # Not shown
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:") # Not shown
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")# Not shown
plt.plot([threshold_90_precision], [0.9], "ro") # Not shown
plt.plot([threshold_90_precision], [recall_90_precision], "ro") # Not shown
save_fig("precision_recall_vs_threshold_plot") # Not shown
plt.show()
Saving figure precision_recall_vs_threshold_plot
Note: as the plot shows, raising the threshold can occasionally make precision drop (the precision curve is bumpy), whereas recall can only fall, never rise, as the threshold increases; see the quick check below.
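A one-line check of the first half of that note, reusing the precisions array computed above: if precision ever decreases as the threshold grows, some consecutive difference is negative.
# True here: precision is not monotonic in the threshold
(np.diff(precisions) < 0).any()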
(y_train_pred == (y_scores > 0)).all()
True
def plot_precision_vs_recall(precisions, recalls):
    plt.plot(recalls, precisions, "b-", linewidth=2)
    plt.xlabel("Recall", fontsize=16)
    plt.ylabel("Precision", fontsize=16)
    plt.axis([0, 1, 0, 1])
    plt.grid(True)
plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
save_fig("precision_vs_recall_plot")
plt.show()
Saving figure precision_vs_recall_plot
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
# The decision threshold at which precision first reaches 0.90
threshold_90_precision
1803.6738636953733
# Make predictions using the 90%-precision threshold
y_train_pred_90 = (y_scores >= threshold_90_precision)
# Check the precision
precision_score(y_train_5, y_train_pred_90)
0.9000673703121491
# Recall is now lower
recall_score(y_train_5, y_train_pred_90)
0.7316538882803943
5. The ROC Curve
# Import the ROC curve from sklearn.metrics
from sklearn.metrics import roc_curve
# Get the true positive rate (i.e. recall) and the false positive rate (1 - true negative rate);
# "positive" and "negative" here refer to the true labels
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'k--')  # dashed diagonal
    plt.axis([0, 1, 0, 1])                                     # Not shown in the book
    plt.xlabel('False Positive Rate (Fall-Out)', fontsize=16)  # Not shown
    plt.ylabel('True Positive Rate (Recall)', fontsize=16)     # Not shown
    plt.grid(True)                                             # Not shown
plt.figure(figsize=(8, 6)) # Not shown
plot_roc_curve(fpr, tpr)
fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)] # Not shown
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:") # Not shown
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:") # Not shown
plt.plot([fpr_90], [recall_90_precision], "ro") # Not shown
save_fig("roc_curve_plot") # Not shown
plt.show()
Saving figure roc_curve_plot
# Import the ROC AUC score from sklearn.metrics
from sklearn.metrics import roc_auc_score
# i.e. compute the area under the ROC curve (between the curve and the x-axis)
roc_auc_score(y_train_5, y_scores)
0.9671233027792312
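As a sanity check of that reading, the same number can be obtained by integrating the curve directly (a sketch; np.trapz performs trapezoidal integration of tpr over fpr):
# Should match roc_auc_score(y_train_5, y_scores) up to numerical error
# (np.trapezoid is the preferred name in NumPy >= 2.0)
np.trapz(tpr, fpr)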
- Note: n_estimators is set to 100 explicitly here because that is the default value in newer versions of Scikit-Learn (≥ 0.22).
# Import the random forest classifier
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
# Note: with method="predict_proba", this returns, for each sample, the class
# probabilities (the mean predicted probabilities across the forest's trees)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
method="predict_proba")
y_probas_forest
array([[1. , 0. ],
[0.97, 0.03],
[0.97, 0.03],
...,
[0.06, 0.94],
[1. , 0. ],
[0.98, 0.02]])
As you can see, the first column holds the probability of the negative class and the second the probability of the positive class.
y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)
# The forest's recall (TPR) at the first point where its FPR reaches the SGD model's fpr_90
recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.plot([fpr_90, fpr_90], [0., recall_for_forest], "r:")
plt.plot([fpr_90], [recall_for_forest], "ro")
plt.grid(True)
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()
Saving figure roc_curve_comparison_plot
Looking at the ROC curves, the closer a curve hugs the top-left corner, the better the model performs. So the random forest classifier outperforms the SGD classifier.
roc_auc_score(y_train_5, y_scores_forest)
0.9984469476493366
y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)
precision_score(y_train_5, y_train_pred_forest)
0.991280880215902
recall_score(y_train_5, y_train_pred_forest)
0.8716684921504199
One last note on choosing a metric: prefer the PR curve when the positive class is rare or when you care more about false positives than false negatives; otherwise use the ROC curve. The sketch below compares the two models on a PR curve.
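As a hedged illustration of that advice (this plot is not in the original notebook), both models can be compared on a PR curve, reusing the arrays computed earlier:
# PR curves for both models; "precisions"/"recalls" are the SGD arrays from above
precisions_forest, recalls_forest, _ = precision_recall_curve(y_train_5, y_scores_forest)
plt.figure(figsize=(8, 6))
plt.plot(recalls, precisions, "b:", linewidth=2, label="SGD")
plt.plot(recalls_forest, precisions_forest, "g-", linewidth=2, label="Random Forest")
plt.xlabel("Recall", fontsize=16)
plt.ylabel("Precision", fontsize=16)
plt.legend(loc="lower left", fontsize=16)
plt.grid(True)
plt.show()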
VI. Multiclass Classification
# Import the support vector machine classifier
from sklearn.svm import SVC
svm_clf = SVC(gamma="auto", random_state=42)
# Train on the first 1,000 samples. Since this is multiclass classification (each sample
# belongs to exactly one of more than two classes), we pass the original labels in [0, 9]
svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5
svm_clf.predict([some_digit])
array([6.])
The classifier here is an SVM; internally, Scikit-Learn trains one binary classifier per pair of classes (one-versus-one, OvO): 0 vs 1, 0 vs 2, 1 vs 2, and so on, i.e. 10 × 9 / 2 = 45 classifiers for 10 classes, as the sketch below verifies.
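To verify the count, Scikit-Learn lets you force the one-versus-one strategy explicitly (as in the book):
from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SVC(gamma="auto", random_state=42))
ovo_clf.fit(X_train[:1000], y_train[:1000])
len(ovo_clf.estimators_)  # 45 = 10 * 9 / 2 pairwise classifiers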
some_digit_scores = svm_clf.decision_function([some_digit])
some_digit_scores
array([[ 2.78447015, 4.93931166, 8.18095238, 6.02785743, 7.1051902 ,
0.77096814, 9.29616663, -0.23717199, 3.82067593, 1.78447015]])
# Index of the largest decision score (these are scores, not probabilities)
np.argmax(some_digit_scores)
6
svm_clf.classes_
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
svm_clf.classes_[6]
6.0
# Import the one-versus-rest (OvR) classifier
from sklearn.multiclass import OneVsRestClassifier
ovr_clf = OneVsRestClassifier(SVC(gamma="auto", random_state=42))
ovr_clf.fit(X_train[:1000], y_train[:1000])
ovr_clf.predict([some_digit])
array([6.])
# Only 10 classifiers this time, one per class:
# 0 vs rest, 1 vs rest, ..., 9 vs rest
len(ovr_clf.estimators_)
10
# Feed the full training set to the SGD classifier; internally it
# automatically trains 10 one-versus-rest binary classifiers
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])
array([6.])
sgd_clf.decision_function([some_digit])
array([[-47338.45231781, -21271.3416702 , -9564.98434185,
-16208.54374001, -8629.77193181, -13563.74950208,
4630.22629936, -17866.72455436, -9697.88455665,
-13667.07951762]])
# Since 10 binary classifiers are trained, this can take roughly
# 10 times as long as training the 5-detector earlier
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
array([0.87555, 0.8752 , 0.87265])
# Import the standard scaler (standardization) from sklearn.preprocessing
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
/home/zhoumingyao/anaconda/yes/envs/tf_gpu/lib/python3.10/site-packages/sklearn/linear_model/_stochastic_gradient.py:705: ConvergenceWarning: Maximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.
warnings.warn(
array([0.8966 , 0.90275, 0.9041 ])
As you can see, accuracy improves a bit after standardizing the inputs.
VII. Error Analysis
# Build the confusion matrix for the 10-class SGD classifier
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
/home/zhoumingyao/anaconda/yes/envs/tf_gpu/lib/python3.10/site-packages/sklearn/linear_model/_stochastic_gradient.py:705: ConvergenceWarning: Maximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.
warnings.warn(
array([[5553, 0, 14, 6, 8, 38, 37, 7, 196, 1],
[ 0, 6423, 41, 23, 2, 44, 6, 4, 173, 10],
[ 20, 28, 5268, 83, 78, 23, 71, 45, 381, 10],
[ 24, 14, 115, 5256, 1, 202, 26, 47, 375, 64],
[ 11, 12, 40, 9, 5230, 7, 40, 21, 298, 175],
[ 35, 17, 21, 163, 53, 4530, 83, 21, 483, 72],
[ 28, 18, 41, 2, 46, 87, 5514, 7, 139, 0],
[ 23, 13, 56, 23, 42, 10, 5, 5673, 181, 213],
[ 16, 52, 41, 91, 4, 124, 32, 11, 5451, 42],
[ 22, 19, 30, 60, 129, 32, 1, 170, 343, 5171]])
# With Scikit-Learn ≥ 0.22 you can call sklearn.metrics.plot_confusion_matrix() directly
# (superseded by ConfusionMatrixDisplay in later releases)
def plot_confusion_matrix(matrix):
    """If you prefer color and a colorbar"""
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
plt.matshow(conf_mx, cmap=plt.cm.gray)
save_fig("confusion_matrix_plot", tight_layout=False)
plt.show()
Saving figure confusion_matrix_plot
# Normalize the confusion matrix row by row, i.e. divide each row
# by the number of images with that true label
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
# Zero out the diagonal so only the errors remain
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
save_fig("confusion_matrix_errors_plot", tight_layout=False)
plt.show()
Saving figure confusion_matrix_errors_plot
In the plot above, the brighter a cell, the worse the confusion. You can see that 8 gets confused with many other digits, especially 5. One remedy is to engineer features that emphasize each digit's distinctive structure, for example using a morphological closing operation to highlight the two loops of an 8 (see the sketch below).
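A minimal sketch of that idea, assuming SciPy is available (the size=(3, 3) structuring element is an arbitrary choice for illustration, not a tuned value):
from scipy import ndimage
# Grayscale morphological closing: dilation followed by erosion, which fills
# small dark gaps such as the inside of an 8's loops
closed = ndimage.grey_closing(X_train[0].reshape(28, 28), size=(3, 3))
plot_digit(closed)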
cl_a, cl_b = 5, 8
# Collect the training images of 5s and 8s, split by prediction outcome:
# 5s correctly classified as 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
# 5s misclassified as 8
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
# 8s misclassified as 5
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
# 8s correctly classified as 8
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
save_fig("error_analysis_digits_plot")
plt.show()
Saving figure error_analysis_digits_plot
VIII. Multilabel Classification
So far, our classifiers returned a single label per sample. But take face recognition as an example: suppose a classifier has learned to recognize Mike, Jack, and Mary, and a given picture contains only Mary; the output label would then be [0, 0, 1]. In other words, a single sample can carry multiple labels (with the number of labels fixed in advance).
# Import the k-nearest neighbors classifier
from sklearn.neighbors import KNeighborsClassifier
# First label: whether the digit is large (>= 7)
y_train_large = (y_train >= 7)
# Second label: whether the digit is odd
y_train_odd = (y_train % 2 == 1)
# Stack the two label vectors into a single multilabel target;
# we train on the same training set with these two labels per sample
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
# Test on the image of a 6: not >= 7 and not odd, so both labels should be False
knn_clf.predict([some_digit])
array([[False, False]])
from sklearn.metrics import f1_score
# Get predictions for every training instance via 3-fold cross-validation
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
# Compute the F1 score, macro-averaged over the two labels
f1_score(y_multilabel, y_train_knn_pred, average="macro")
0.9776279826458136
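The macro average assumes both labels matter equally; if they don't, one option (mentioned in the book) is to weight each label's score by its support, i.e. its number of positive instances:
# Weight each label's F1 by the number of instances with that label
f1_score(y_multilabel, y_train_knn_pred, average="weighted")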
IX. Multioutput Classification
All the classification tasks discussed so far were single-output: one output per image (even the multilabel case only output a fixed, predefined number of labels). Multioutput classification here means that for one input image, the system outputs a value for every pixel (784 of them), namely the denoised image.
# Generate random noise (setting np.random.seed beforehand would make this reproducible)
noise = np.random.randint(0, 100, (len(X_train), 784))
# Add noise to the training images; the noisy images are the inputs
X_train_mod = X_train + noise
# Add noise to the test set as well
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
# The original (clean) images serve as the training and test targets
y_train_mod = X_train
y_test_mod = X_test
some_index = 0
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
save_fig("noisy_digit_example_plot")
plt.show()
Saving figure noisy_digit_example_plot
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_test_mod[some_index]])
plot_digit(clean_digit)
save_fig("cleaned_digit_example_plot")
Saving figure cleaned_digit_example_plot