混淆矩阵计算实现

179 阅读1分钟

实现 `fast_hist` 函数，并使用 `sklearn.metrics.confusion_matrix` 验证其结果。

from sklearn.metrics import confusion_matrix
import numpy as np

def fast_hist(a, b, n):
    """Compute an ``n x n`` confusion matrix between two label arrays.

    Entry ``[i, j]`` counts the positions where ``a == i`` and ``b == j``.
    With ``a`` as ground truth and ``b`` as prediction this is the same
    layout as ``sklearn.metrics.confusion_matrix(a, b)``.

    Args:
        a: Array-like of integer labels of any shape (e.g. ground truth).
        b: Array-like of integer labels with the same shape as ``a``.
        n: Number of classes; valid labels are ``0 .. n-1``.

    Returns:
        ``np.ndarray`` of shape ``(n, n)`` holding integer counts.

    Raises:
        ValueError: if ``a`` and ``b`` do not have the same shape.

    Positions where either label lies outside ``[0, n)`` (for example an
    "ignore" label such as 255) are skipped rather than counted.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: a{a.shape} vs b{b.shape}")

    # Keep only positions where BOTH labels are valid. An out-of-range b
    # would otherwise be encoded into the wrong cell, or make bincount fail.
    valid = (a >= 0) & (a < n) & (b >= 0) & (b < n)

    # Core trick: encode each (a, b) pair as the single index a * n + b,
    # which lies in [0, n**2). One vectorized bincount then counts every
    # pair at once, and reshaping to (n, n) recovers rows=a, columns=b.
    flat = n * a[valid].astype(int) + b[valid].astype(int)
    count = np.bincount(flat, minlength=n**2)
    return count.reshape(n, n)


n_classes = 3  # labels in this example are 0, 1 and 2

# 2-D label maps (semantic-segmentation style); flatten to 1-D per-pixel labels.
a = [[0, 1, 0], [2, 1, 0], [2, 2, 1]]
b = [[0, 2, 0], [2, 1, 0], [1, 2, 1]]
a = np.array(a).reshape(-1)
b = np.array(b).reshape(-1)

print("a", a)
print("b", b)

hist = fast_hist(a, b, n_classes)
# Pass `labels` explicitly: without it sklearn infers the class set from the
# data, so a class missing from both a and b would shrink sklearn's matrix
# and it would no longer line up with fast_hist's fixed n x n result.
reference = confusion_matrix(a, b, labels=np.arange(n_classes))
print("confusionMatrix", hist)
print("sklearn.metrics.confusion_matrix", reference)
print("matrices match:", np.array_equal(hist, reference))

输出如下:

![输出结果](image.png)