Confusion Matrix
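Principle
In machine learning, and especially in statistical classification, a confusion matrix (also called an error matrix) is a table that compares a classifier's predictions with the true classes of the samples.
Each column of the matrix represents the class predicted by the classifier, while each row represents the true class that the sample belongs to.
It is called a "confusion matrix" because it makes it easy to see whether the model has confused one class with another: correct predictions fall on the diagonal, and every off-diagonal entry counts samples of one class that were predicted as another.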
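To make the row/column convention concrete, here is a small hand-worked sketch in Python. The three classes and their names (0 = cat, 1 = dog, 2 = bird) are hypothetical and chosen only for illustration, and the counting loop assumes integer class indices starting at 0.

```python
import numpy as np

# Hypothetical 3-class example: 0 = cat, 1 = dog, 2 = bird (illustrative only)
labels = [0, 0, 0, 1, 1, 1, 2, 2]
predictions = [0, 0, 1, 1, 1, 0, 2, 2]

num_classes = 3
matrix = np.zeros((num_classes, num_classes), dtype=int)
for true_class, predicted_class in zip(labels, predictions):
    # Row = true class, column = predicted class
    matrix[true_class, predicted_class] += 1

print(matrix)
# [[2 1 0]
#  [1 2 0]
#  [0 0 2]]

# List every off-diagonal (confused) entry
for t in range(num_classes):
    for p in range(num_classes):
        if t != p and matrix[t, p] > 0:
            print(f"true class {t} predicted as {p}: {matrix[t, p]} time(s)")
```

The non-zero entries at (0, 1) and (1, 0) show that cats and dogs were mistaken for each other, while the bird row has all of its samples on the diagonal.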
Code
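import numpy as np


def confusion_matrix(labels, predictions, class_num=45, normalization=True):
    """
    :param labels: list of true class indices, for example [1, 2, 3, 4, 5, 6, 7]
    :param predictions: list of predicted class indices, same length as labels, for example [1, 2, 3, 4, 6, 6, 7]
    :param class_num: total number of classes; every index must lie in [0, class_num)
    :param normalization: if True, normalize each row to [0, 1] so that it sums to 1
    :return: matrix of shape [class_num, class_num]; rows are true classes, columns are predicted classes
    """
    assert len(labels) == len(predictions)
    _matrix = np.zeros((class_num, class_num))
    # _total[i] counts the samples whose true class is i
    _total = np.zeros(class_num)
    for label, prediction in zip(labels, predictions):
        # Row = true class, column = predicted class
        _matrix[label, prediction] += 1
        _total[label] += 1
    if normalization:
        # Divide each row by its class count; rows of classes with no samples stay 0
        _matrix = np.divide(_matrix, _total[:, None],
                            out=np.zeros_like(_matrix), where=_total[:, None] > 0)
    return _matrix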
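The function above fills the matrix with a Python loop over the samples. As an alternative sketch, the same layout (rows = true class, columns = predicted class, optional row normalization) can be built with NumPy's `np.bincount` by encoding each (true, predicted) pair as one flat index `true * class_num + predicted`. The name `confusion_matrix_vectorized` is hypothetical, and the sketch assumes every label and prediction is an integer in `[0, class_num)`; classes with no samples get an all-zero row rather than a division by zero.

```python
import numpy as np


def confusion_matrix_vectorized(labels, predictions, class_num, normalization=True):
    """Hypothetical vectorized variant: rows = true class, columns = predicted class.

    Assumes all indices are integers in [0, class_num).
    """
    labels = np.asarray(labels, dtype=np.int64)
    predictions = np.asarray(predictions, dtype=np.int64)
    assert labels.shape == predictions.shape
    # Encode each (true, predicted) pair as one flat index into a class_num x class_num grid
    flat = labels * class_num + predictions
    counts = np.bincount(flat, minlength=class_num * class_num)
    matrix = counts.reshape(class_num, class_num).astype(float)
    if normalization:
        row_sums = matrix.sum(axis=1, keepdims=True)
        # Row-normalize; rows for classes with no samples are left as zeros
        matrix = np.divide(matrix, row_sums, out=np.zeros_like(matrix), where=row_sums > 0)
    return matrix


# Usage with the example lists from the docstring (8 classes so that index 7 is valid)
m = confusion_matrix_vectorized([1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 6, 6, 7], class_num=8)
print(m[5])  # row for true class 5, which was predicted as class 6
# [0. 0. 0. 0. 0. 0. 1. 0.]
```

With normalization, row `i` gives the fraction of class-`i` samples assigned to each predicted class, so the diagonal holds the fraction of each class that was classified correctly.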