代码有些参考了其他博客
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
def plot_confusion_matrix(cm, classes, normalize=False, cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
Input
- cm : 计算出的混淆矩阵的值
- classes : 混淆矩阵中每一行每一列对应的列
- normalize : True:显示百分比, False:显示个数
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
title = 'Normalized confusion matrix'
cb_title = 'Number of recordings (normalized)'
else:
title =</

本文介绍了一个用于绘制混淆矩阵的Python函数,该函数可以根据实际标签和预测标签生成混淆矩阵,并支持是否进行归一化的选项。此外,还展示了如何使用该函数从PyTorch模型输出的数据中生成混淆矩阵。
最低0.47元/天 解锁文章
5864

被折叠的 条评论
为什么被折叠?



