机器学习中的性能度量

129 阅读9分钟

写在前面

在机器学习当中对于模型的性能度量有许多方式,但是在这篇文章中主要介绍准确率,召回率,精准率、ROC曲线、P-R曲线以及如何通过混淆矩阵来获取到真正例(TP)、假正例(FP)、真反例(TN)、假反例(FN)

1.模型训练常见的两种情况

在进行机器学习的模型训练的过程中,如果对于模型的选择有误 ---> 即数据集的各种特征模型并不能完全学习或者学习的有偏差,亦或是对数据集的处理不恰当,可能就会出现两种极端的情况: 过拟合欠拟合

因此,在正式介绍机器学习的性能度量之前先介绍一下过拟合欠拟合

1.1 欠拟合

1.1.1 什么是欠拟合

欠拟合是指不能很好的从训练数据中,学习到有用的数据模式,从而针对训练数据和待预测的数据,均不能获得很好的预测效果。如果使用的训练样本过少,较容易获得欠拟合的训练模型。

1.1.2 解决方法

我们一般采用以下方法来减小过拟合的影响:

  1. 提取更多有效特征。
  2. 使用更强大的模型,如深度学习模型。
  3. 增加更多训练数据。
  4. 增大迭代次数,优化算法运行得更久一点。
  5. 降低正则化强度,允许模型更好地拟合训练数据。
  6. 检查数据是否有问题,如有离群点等异常值。

1.2 过拟合

1.2.1 什么是欠拟合

过拟合是与欠拟合相对的概念,欠拟合是对于数据集的特征学习不足,而过拟合是对数据集的特征学习"太好"的表现,将数据集中的噪声点也一并学习了,导致在训练集上表现非常好,但是到测试集上预测结果非常差。过拟合可能给出不准确的预测,并且无法对所有类型的新数据表现良好。

1.2.2 解决方法

正则化

正则化方法包括L0正则、L1正则和L2正则。 L0范数是指向量中非0的元素的个数。L1范数是指向量中各个元素绝对值之和,也叫“稀疏规则算子”(Lasso regularization)。两者都可以实现稀疏性。 L2范数是指向量各元素的平方和然后求平方根。可以使得W的每个元素都很小,都接近于0,但与L1范数不同,它不会让它等于0,而是接近于0。L2正则项起到使得参数w变小加剧的效果。

剪枝

剪枝是决策树中一种控制过拟合的方法,预剪枝通过在训练过程中控制树深、叶子节点数、叶子节点中样本的个数等来控制树的复杂度。后剪枝则是在训练好树模型之后,采用交叉验证的方式进行剪枝以找到最优的树模型。

提前终止迭代

主要是用在神经网络中的,在神经网络的训练过程中我们会初始化一组较小的权值参数,此时模型的拟合能力较弱,通过迭代训练来提高模型的拟合能力,随着迭代次数的增大,部分的权值也会不断的增大。如果我们提前终止迭代可以有效的控制权值参数的大小,从而降低模型的复杂度。

原文链接:blog.csdn.net/zhuanzhe117…

2.常用的模型性能度量

在前面介绍了在机器学习中容易出现的两种情况:过拟合欠拟合,那么如何避免或者提前发现这两种情况,那这就需要机器学习的性能度量了,性能度量能够比较客观地反应模型的性能,帮助我们更好的去调整模型或是数据集的处理。

2.1 准确率

在正式介绍准确率之前先阐述几个数值的概念,这在后续的性能度量指标中也会用到。

真正例(True Positive,TP):一个正例被正确预测为正例

真反例(True Negative,TN):一个反例被正确预测为反例

假正例(False Positive,FP):一个反例被错误预测为正例

假反例(False Negative,FN):一个正例被错误预测为反例

上面几个数值在下面的性能度量指标中也会用到

准确率是所有样本中预测正确的样本占比,其公式如下:

Accuracy=TP+TNTP+TN+FP+FN Accuracy = \frac{TP+TN}{TP+TN+FP+FN}

准确率有一个明显的弊端问题,就是在数据的类别不均衡,特别是有极端的数据存在的情况下,准确率这个评价指标并不能客观准确评价模型优劣。

2.2 混淆矩阵

混淆矩阵是机器学习中总结分类模型预测结果的情形分析表,以矩阵形式将数据集中的记录按照真实的类别与分类模型预测的类别判断两个标准进行汇总 混淆矩阵要表达的含义:

  • 混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;
  • 每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目;每一列中的数值表示真实数据被预测为该类的数目。

image.png

上述图片可用于理解二分类混淆矩阵,通过添加更多的行和列,就可以得到多分类混淆矩阵。

后续的TP、TN、FP、FN等值都是由混淆矩阵获得的

2.2.1 二分类混淆矩阵

以分类模型中最简单的二分类为例,对于这种问题,我们的模型最终需要判断样本的结果是0还是1,或者说是positive还是negative。

我们通过样本的采集,能够直接知道真实情况下,哪些数据结果是positive,哪些结果是negative。同时,我们通过用样本数据跑出分类型模型的结果,也可以知道模型认为这些数据哪些是positive,哪些是negative。 将这四个指标一起呈现在表格中,就能得到如下这样一个矩阵,我们称它为混淆矩阵

image.png

2.2.2 多分类混淆矩阵

在二分类混淆矩阵的基础上,分别在横向和纵向增加行和列,大致如下:

image.png

2.3 精准率

精准率(Precision)又称查准率,它是针对预测结果而言的,它的含义是在所有被预测为正的样本中实际为正的样本的概率,即在预测为正样本的结果中,有多少把握可以预测正确,其公式如下:

Precision=TPTP+TNPrecision = \frac{TP}{TP+TN}

2.4 召回率

召回率(Recall)又称查全率,它是针对原样本而言的,它的含义是在实际为正的样本中被预测为正样本的概率,其公式如下:

Recall=TPTP+FNRecall = \frac{TP}{TP+FN}

Precision和Recall是一对此消彼长的度量。例如在推荐系统中,想让推送的内容尽可能用户全都感兴趣,那只能推送把握高的内容,这样就漏掉了一些用户感兴趣的内容,Recall就低了;如果想让用户感兴趣的内容都被推送,那只有将所有内容都推送上,宁可错杀一千,不可放过一个,这样Precision就很低了。

引用文章:【最全的】分类算法的性能度量指标 - 知乎 (zhihu.com))

2.5 P-R曲线和ROC曲线

2.5.1 P-R曲线

P-R曲线(Precision Recall Curve )表现的是Precision和Recall之间的关系。

P-R曲线定义如下:根据学习器的预测结果(一般为一个实值或概率)对测试样本进行排序,将最可能是“正例”的样本排在前面,最不可能是“正例”的排在后面,按此顺序逐个把样本作为“正例”进行预测,每次计算出当前的P值和R值

image.png

当P-R曲线越靠近右上方时,表明模型性能越好。在对不同模型进行比较时,一般有以下两种情况:

  • 若一个模型的P-R曲线被另一个模型的P-R曲线完全包住,则说明后者的性能优于前者,如上图中A代表的模型要优于C代表的模型。
  • 若模型的P-R曲线发生了交叉,则谁的曲线下的面积大,谁的性能更优。但一般来说,曲线下面积很难估算,所以衍生出了“平衡点”(Break-Event Point,简称BEP),如上图所示红色点,即当P=R时的取值,平衡点的取值越高,性能更优。

2.5.2 ROC曲线

ROC曲线全称为受试者工作特征曲线(Receiver Operating Characteristic Curve)。ROC是一张图上的一条线(如下图所示),越靠近左上角的ROC曲线,模型的准确度越高,模型越理想

image.png

图ROC曲线中,横轴是假阳率(False positive rate ,简称FPR),定义为 在所有真实的负样本中,被模型错误的判断为正例的比例 ,计算公式如下

FPR=FPFP+TNFPR = \frac{FP}{FP+TN}

纵轴是真阳率(True Positive Rate,简称TPR),定义为 在所有真实的正样本中,被模型正确的判断为正例的比例,其实就是召回率,计算公式如下:

TPR=FPTP+FNTPR = \frac{FP}{TP+FN}

3.代码实现

准备工作

# 计算TP,FP,TN,FN
def perf_measure(y_true,y_pred):
    TP, FP, TN, FN = 0, 0, 0, 0
    matx = np.zeros((3,3))
    for i in range(len(y_true)):
        matx[y_true[i]][y_pred[i]] += 1
    TP = np.diag(matx)
    FP = matx.sum(axis=0) - TP
    FN = matx.sum(axis=1) - TP
    TN = matx.sum() - (TP + FP + FN)
    FPR = FP/(FP+TN)
    TPR = TP/(TP+FN)
    print(FP,TP,FN,FPR,TPR)
    return TP/(FP+TP),TP/(FN+TP),FPR,TPR

y_true和y_pred分别代表真实值和预测值

P-R曲线绘制

# 图像绘制
def drawpr(x,y,x_table,y_table,title):
    plt.rc("font",family="FangSong")
    plt.plot(x,y,marker='o')
    plt.xlabel(x_table,fontsize=20)
    plt.ylabel(y_table,fontsize=20)
    plt.title(title,fontsize=20)
    plt.show() 
# 调用图像绘制函数,绘制P-R曲线
drawpr(re_array,pr_array,"Recall","Precision","PR曲线")

re_array和pr_array分别代表 召回率和精准率

ROC曲线绘制

# 图像绘制
def drawpr(x,y,x_table,y_table,title):
    plt.rc("font",family="FangSong")
    plt.plot(x,y,marker='o')
    plt.xlabel(x_table,fontsize=20)
    plt.ylabel(y_table,fontsize=20)
    plt.title(title,fontsize=20)
    plt.show() 
# 调用图像绘制函数,绘制ROC曲线
drawpr(fpr_array,tpr_array,"FPR","TPR","ROC曲线")  

fpr_array和tpr_array分别代表 假正例率和真正例率