multi-class分类模型评估指标的定义、原理及其Python实现 (3)

189 阅读4分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 25 天,点击查看活动详情

4. Recall值

真实标签为正值的被预测正确的概率。

sklearn的函数文档:sklearn.metrics.recall_score — scikit-learn 1.1.1 documentation

4.1 Micro-R

计算所有真实标签为正值的被预测正确的概率。

使用Python的原生函数实现:

import json

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

lp_list=[x.count(1) for x in label]
lp=sum(lp_list)
tp_list=[[label[y][x]==1 and prediction[y][x]==1 for x in range(len(prediction[0]))].count(True) for y in range(len(prediction))]
tp=sum(tp_list)
print(tp/lp)

使用sklearn实现:

import json
from sklearn.metrics import recall_score

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

print(recall_score(label,prediction, average='micro'))

输出:0.4684014869888476

4.2 Macro-R

计算每一类标签对应的真实正值中预测正确的比例,然后将所有标签的R值求平均。 如果某一类标签没有真实正值,sklearn的默认处理方式是将R值置0并报警告信息,本文在原生函数实现中也采用了这一方法。(这种情况有毛病,不应该出现这种情况的)

使用Python原生函数实现:

import json
from statistics import mean

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

p_list=[0 for _ in range(len(label[0]))]
for label_index in range(len(label[0])):
    l=[x[label_index] for x in label]
    p=[x[label_index] for x in prediction]
    if l.count(1)==0:
        print('索引为'+str(label_index)+'的标签无真实正值!')
    else:
        p_list[label_index]=[l[x]==1 and p[x]==1 for x in range(len(l))].count(1)/l.count(1)
print(mean(p_list))

输出:

0.21012970014737198

使用sklearn实现:

import json
from sklearn.metrics import recall_score

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

print(recall_score(label,prediction, average='macro'))

输出:

0.210129700147372

这个差异明显地是由于精度问题。

5. F1得分

F1=2(precisionrecall)/(precision+recall)F1 = 2 * (precision * recall) / (precision + recall)

sklearn的函数文档:sklearn.metrics.f1_score — scikit-learn 1.1.1 documentation

5.1 Micro-F1

micro F1的得分分别就是算micro的P和R,用原生Python直接把对应的结果算出来再计算F1值即可。以下仅介绍使用sklearn的实现方式。

import json
from sklearn.metrics import f1_score

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

print(f1_score(label,prediction, average='micro'))

输出:0.5974710221285564 2*(0.8247272727272728*0.4684014869888476)/(0.8247272727272728+0.4684014869888476)

5.2 Macro-F1

macro-F1是计算每一类的F1值,然后求平均。 sklearn对除以0的默认处理方式是将结果置0并报警告信息,本文在原生函数实现中也采用了这一方法。

使用Python原生函数实现:

import json
from statistics import mean

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

p_list=[0 for _ in range(len(label[0]))]
r_list=[0 for _ in range(len(label[0]))]
for label_index in range(len(label[0])):
    l=[x[label_index] for x in label]
    p=[x[label_index] for x in prediction]
    if p.count(1)==0:
        print('索引为'+str(label_index)+'的标签无正预测值!')
    else:
        p_list[label_index]=[l[x]==1 and p[x]==1 for x in range(len(l))].count(1)/p.count(1)
    
    if l.count(1)==0:
        print('索引为'+str(label_index)+'的标签无真实正值!')
    else:
        r_list[label_index]=[l[x]==1 and p[x]==1 for x in range(len(l))].count(1)/l.count(1)

f_list=[(0 if p_list[x]+r_list[x]==0 else 2*p_list[x]*r_list[x]/(p_list[x]+r_list[x])) for x in range(len(label[0]))]
print(mean(f_list))

输出:

索引为26的标签无正预测值!
索引为28的标签无正预测值!
索引为30的标签无正预测值!
索引为32的标签无正预测值!
索引为35的标签无正预测值!
索引为36的标签无正预测值!
索引为37的标签无正预测值!
索引为41的标签无正预测值!
索引为42的标签无正预测值!
索引为44的标签无正预测值!
索引为45的标签无正预测值!
索引为46的标签无正预测值!
索引为47的标签无正预测值!
索引为48的标签无正预测值!
索引为49的标签无正预测值!
索引为50的标签无正预测值!
索引为51的标签无正预测值!
索引为52的标签无正预测值!
索引为53的标签无正预测值!
0.26380909234445127

使用sklearn的实现方式:

import json
from sklearn.metrics import f1_score

label=json.load(open('data/cls/AAPD/label.json'))
prediction=json.load(open('data/cls/AAPD/prediction.json'))

print(f1_score(label,prediction, average='macro'))

输出:0.26380909234445127

6. 其他

  1. 本文使用的示例是multi-class multi-label任务,如果是multi-class one-label任务的话,还会出现另一种特性,就是accuracy==micro F1。可以参考这两个网站:accuracy f1 为什么多分类 等于micro - CSDNmachine learning - Is F1 micro the same as Accuracy? - Stack Overflow