predict = output.argmax(dim = 1)
confusion_matrix =torch.zeros(2,2)
for t, p in zip(predict.view(-1), target.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
a_p =(confusion_matrix.diag() / confusion_matrix.sum(1))[0]
b_p = (confusion_matrix.diag() / confusion_matrix.sum(1))[1]
a_r =(confusion_matrix.diag() / confusion_matrix.sum(0))[0]
b_r = (confusion_matrix.diag() / confusion_matrix.sum(0))[1]
本文介绍如何使用PyTorch计算预测结果与真实标签的混淆矩阵,并从中提取准确率和召回率等性能指标。通过实例代码展示了predict输出的最大概率类别与目标标签对比的过程,进而构建混淆矩阵并计算各类别上的准确率和召回率。
1万+





