在 Python 中,绘制混淆矩阵通常会用到 `seaborn` 和 `matplotlib` 库,同时会借助 `sklearn.metrics` 中的 `confusion_matrix` 函数来计算混淆矩阵。以下是几种常见的绘制方法:
#### 方法一:使用 `seaborn` 的 `heatmap` 函数
```python
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
sns.set()
f, ax = plt.subplots()
y_true = [0, 0, 1, 2, 1, 2, 0, 2, 2, 0, 1, 1]
y_pred = [1, 0, 1, 2, 1, 0, 0, 2, 2, 0, 1, 1]
C2 = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(C2)
sns.heatmap(C2, annot=True, ax=ax)
ax.set_title('confusion matrix')
ax.set_xlabel('predict')
ax.set_ylabel('true')
plt.show()
```
此方法先使用 `confusion_matrix` 函数计算混淆矩阵,再利用 `seaborn` 的 `heatmap` 函数将其绘制成热力图,最后通过 `matplotlib` 显示图形[^1]。
#### 方法二:自定义标签绘制混淆矩阵
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 假设这是你的混淆矩阵数据
confusion_matrix = np.array([
[680, 18, 133, 95, 3],
[22, 406, 25, 93, 11],
[99, 24, 348, 217, 7],
[45, 43, 83, 1149, 78],
[5, 7, 8, 127, 406]
])
# 自定义的标签名称
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
# 设置绘图
plt.figure(figsize=(10, 7))
ax = sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues', cbar_kws={'label': 'Scale'}, xticklabels=labels, yticklabels=labels)
# 添加标题和轴标签
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# 优化标签显示
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
# 显示图形
plt.show()
```
该方法同样使用 `seaborn` 的 `heatmap` 函数绘制混淆矩阵,不同之处在于可以自定义标签名称,并且能对图形的大小、颜色、标签显示等进行更多的设置[^2]。
#### 方法三:使用 `matplotlib` 的 `matshow` 函数
```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
def cm_plot(original_label, predict_label, pic=None):
cm = confusion_matrix(original_label, predict_label)
plt.figure()
plt.matshow(cm, cmap=plt.cm.Blues)
plt.colorbar()
for x in range(len(cm)):
for y in range(len(cm)):
plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.title('confusion matrix')
if pic is not None:
plt.savefig(str(pic) + '.jpg')
plt.show()
```
此方法使用 `matplotlib` 的 `matshow` 函数绘制混淆矩阵,同时通过 `annotate` 函数在每个单元格中添加具体的数值,还可以将图形保存为图片[^3]。