Python中生成并绘制混淆矩阵(confusion matrix)

博客内容为转载,但未提供具体转载信息。
Python 中,绘制混淆矩阵通常会用到 `seaborn` 和 `matplotlib` 库,同时会借助 `sklearn.metrics` 中的 `confusion_matrix` 函数来计算混淆矩阵。以下是几种常见的绘制方法: #### 方法一:使用 `seaborn` 的 `heatmap` 函数 ```python import seaborn as sns from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt sns.set() f, ax = plt.subplots() y_true = [0, 0, 1, 2, 1, 2, 0, 2, 2, 0, 1, 1] y_pred = [1, 0, 1, 2, 1, 0, 0, 2, 2, 0, 1, 1] C2 = confusion_matrix(y_true, y_pred, labels=[0, 1, 2]) print(C2) sns.heatmap(C2, annot=True, ax=ax) ax.set_title('confusion matrix') ax.set_xlabel('predict') ax.set_ylabel('true') plt.show() ``` 此方法先使用 `confusion_matrix` 函数计算混淆矩阵,再利用 `seaborn` 的 `heatmap` 函数将其绘制成热力图,最后通过 `matplotlib` 显示图形[^1]。 #### 方法二:自定义标签绘制混淆矩阵 ```python import numpy as np import matplotlib.pyplot as plt import seaborn as sns # 假设这是你的混淆矩阵数据 confusion_matrix = np.array([ [680, 18, 133, 95, 3], [22, 406, 25, 93, 11], [99, 24, 348, 217, 7], [45, 43, 83, 1149, 78], [5, 7, 8, 127, 406] ]) # 自定义的标签名称 labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E'] # 设置绘图 plt.figure(figsize=(10, 7)) ax = sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues', cbar_kws={'label': 'Scale'}, xticklabels=labels, yticklabels=labels) # 添加标题和轴标签 plt.title('Confusion Matrix') plt.xlabel('Predicted Label') plt.ylabel('True Label') # 优化标签显示 plt.xticks(rotation=45, ha='right') plt.yticks(rotation=0) # 显示图形 plt.show() ``` 该方法同样使用 `seaborn` 的 `heatmap` 函数绘制混淆矩阵,不同之处在于可以自定义标签名称,且能对图形的大小、颜色、标签显示等进行更多的设置[^2]。 #### 方法三:使用 `matplotlib` 的 `matshow` 函数 ```python import matplotlib.pyplot as plt from sklearn.metrics import confusion_matrix import numpy as np def cm_plot(original_label, predict_label, pic=None): cm = confusion_matrix(original_label, predict_label) plt.figure() plt.matshow(cm, cmap=plt.cm.Blues) plt.colorbar() for x in range(len(cm)): for y in range(len(cm)): plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center') plt.ylabel('True label') plt.xlabel('Predicted label') plt.title('confusion matrix') if pic is not None: plt.savefig(str(pic) + '.jpg') plt.show() ``` 此方法使用 `matplotlib` 的 `matshow` 函数绘制混淆矩阵,同时通过 `annotate` 函数在每个单元格中添加具体的数值,还可以将图形保存为图片[^3]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值