import numpy as np
from prettytable import PrettyTable
class ConfusionMatrix(object):
    """Multi-class confusion matrix with per-class classification metrics.

    ``self.matrix[t, p]`` counts samples whose true class index is ``t`` and
    whose predicted class index is ``p``: rows are ground truth, columns are
    predictions. When class ``i`` is treated as the positive class::

        TP = matrix[i, i]
        FN = sum(matrix[i, :]) - TP   # truly class i, predicted as another
        FP = sum(matrix[:, i]) - TP   # predicted class i, truly another
        TN = sum(matrix) - TP - FP - FN
    """

    def __init__(self, num_classes: int, labels: list):
        """Create an all-zero ``num_classes x num_classes`` matrix.

        Args:
            num_classes: number of classes; indices are 0..num_classes-1.
            labels: display name for each class index, in index order.
        """
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    @staticmethod
    def _ratio(num, den):
        """Return ``num / den`` rounded to 3 decimals, or 0.0 when ``den`` is 0."""
        return round(num / den, 3) if den != 0 else 0.

    @staticmethod
    def _mean(values):
        """Return the mean of ``values`` rounded to 3 decimals, or 0.0 if empty."""
        # np.mean([]) would return nan with a RuntimeWarning.
        return round(np.mean(values), 3) if values else 0.

    def _print_matrix(self):
        """Print ``self.matrix`` as a table with class names on both axes."""
        table = PrettyTable()
        # list() so a tuple of labels also works ([""] + tuple raises TypeError).
        table.field_names = [""] + list(self.labels)
        for i, row in enumerate(self.matrix):
            # Counts are stored as floats; show them as integers.
            table.add_row([self.labels[i]] + [int(v) for v in row])
        print("Confusion Matrix:")
        print(table)

    def update(self, preds: list, label: list, *, accumulate: bool = False,
               verbose: bool = True):
        """Tally predictions against ground truth and print the matrix.

        By default the matrix is rebuilt from scratch, so after the call it
        reflects only this ``preds``/``label`` pair. Pass ``accumulate=True``
        to add the new counts to those already stored (e.g. when evaluating
        batch by batch), and ``verbose=False`` to skip printing.

        Args:
            preds: predicted class indices.
            label: ground-truth class indices, aligned with ``preds``.
            accumulate: add to the existing counts instead of replacing them.
            verbose: print the matrix as a table (the original behavior).

        Raises:
            ValueError: if ``preds`` and ``label`` have different lengths
                (``zip`` would otherwise silently drop the surplus items).
        """
        preds, label = list(preds), list(label)
        if len(preds) != len(label):
            raise ValueError(
                "preds and label must have the same length, got "
                f"{len(preds)} and {len(label)}"
            )
        counts = np.zeros((self.num_classes, self.num_classes))
        for p, t in zip(preds, label):
            counts[t, p] += 1  # row = true class, column = predicted class
        self.matrix = self.matrix + counts if accumulate else counts
        if verbose:
            self._print_matrix()

    def summary(self, whether_ci=False, *, verbose: bool = True):
        """Compute overall accuracy and per-class metrics from ``self.matrix``.

        Populates ``self.acc`` (overall accuracy) and the per-class lists
        ``self.precision``, ``self.sensitivity`` (recall), ``self.specificity``
        and ``self.npv``. ``self.Precision``/``Recall``/``Specificity``/``Npv``
        are left holding the last class's values, as before. Every ratio is
        rounded to 3 decimals and reported as 0.0 when its denominator is 0
        (including overall accuracy on an empty matrix).

        Args:
            whether_ci: accepted for compatibility; confidence intervals are
                not implemented and this flag is ignored.
            verbose: print the accuracy line and the metrics table.

        Returns:
            One row ``[label, accuracy, precision, recall, specificity, npv]``
            per class, followed by an ``"Average"`` row of macro averages.
            The accuracy column repeats the overall accuracy on every row.
        """
        self.sensitivity = []
        self.precision = []
        self.specificity = []
        self.accuracy = []  # kept for backward compatibility; never populated
        self.npv = []

        total = np.sum(self.matrix)
        # Correct predictions lie on the diagonal; guard against 0/0 -> nan.
        self.acc = self._ratio(np.trace(self.matrix), total)

        rows = []
        for i in range(self.num_classes):
            tp = self.matrix[i, i]                # diagonal
            fn = np.sum(self.matrix[i, :]) - tp   # rest of row i
            fp = np.sum(self.matrix[:, i]) - tp   # rest of column i
            tn = total - tp - fp - fn
            self.Precision = self._ratio(tp, tp + fp)
            self.Recall = self._ratio(tp, tp + fn)
            self.Specificity = self._ratio(tn, tn + fp)
            self.Npv = self._ratio(tn, tn + fn)
            rows.append([self.labels[i], self.acc, self.Precision,
                         self.Recall, self.Specificity, self.Npv])
            self.sensitivity.append(self.Recall)
            self.precision.append(self.Precision)
            self.specificity.append(self.Specificity)
            self.npv.append(self.Npv)

        rows.append(["Average", self.acc,
                     self._mean(self.precision),
                     self._mean(self.sensitivity),
                     self._mean(self.specificity),
                     self._mean(self.npv)])

        if verbose:
            print("the model accuracy is ", self.acc)
            table = PrettyTable()
            table.field_names = ["", "Accuracy", "Precision", "Recall", "Specificity", "Npv"]
            for row in rows:
                table.add_row(row)
            print(table)
        return rows
1. 首先初始化类：ac = ConfusionMatrix(num_classes（类别个数）, labels（按类别索引顺序排列的标签名称列表）)
示例输入：
ac = ConfusionMatrix(3, ['类别1', '类别2', '类别3'])
2. 调用方式：
ac.update([预测值], [实际值])  # 第一个参数为预测类别索引，第二个参数为真实类别索引
ac.summary()  # 计算并打印准确率及各类别指标
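3. 输出样式：
调用 update 后打印 "Confusion Matrix:" 及混淆矩阵表格（行为真实类别，列为预测类别）；调用 summary 后打印模型整体准确率，以及包含各类别和平均值（Average）的 Accuracy、Precision、Recall、Specificity、Npv 表格。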



