原文:
towardsdatascience.com/heatmap-for-confusion-matrix-in-python-20a9fc689665
图片由作者提供
简介
混淆矩阵是展示机器学习模型犯的错误类型的一种便捷方式。它是一个 N 行 N 列的数字网格,其中 [n, m] 单元格中的值表示被标注为第 n 类并被识别为第 m 类的示例数量。在本教程中,我将重点介绍如何创建混淆矩阵和热图。颜色调色板将用于显示不同组的大小,使其容易注意到组大小之间的相似性或显著差异。当你处理大量类别时,这种可视化方法非常有用。
这里是混淆矩阵元素的一个视觉解释。
图片由作者提供
请记住,用于演示混淆矩阵的数据是人工的,并不代表任何真实的分类模型。
现在,我将逐步解释如何使用 Python 模块生成这样的混淆矩阵。
Python 基础
要使用热图创建混淆矩阵,你需要三个模块:
pip install scikit-learn, seaborn, pandas
假设你拥有两个预测列表和真实标签列表,你需要执行以下操作:
-
计算混淆矩阵 –
confusion_matrix -
将变量转换为数据框 –
pd.DataFrame -
创建热图图表 –
sn.heatmap -
最后,将图表保存到文件中 –
cfm_plot.figure.savefig
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sn
if __name__ == '__main__':
predictions = ["None", "Dog", "Cat", ...]
true_labels = ["None", "Dog", "Dog", ...]
cm = confusion_matrix(true_labels, predictions)
df_cfm = pd.DataFrame(cm)
cfm_plot = sn.heatmap(df_cfm)
cfm_plot.figure.savefig("data/confusion_matrix_v1.png")
这里是输出结果:
图片由作者提供
输出完全不起眼。默认情况下,许多有用的信息和自定义功能被禁用或不适合我们的数据。让我们改进这个图表。
标签
我们需要创建一个显示标签的标签列表。我们可以使用预测和真实标签的信息(第一行)来做这件事。为了提高可读性并保持运行之间的相同顺序,我们将多数类(None)移动到第一个元素,并对剩余的标签进行排序(第二行)。如果不这样做,标签的顺序可能会在每次运行代码时不同。
接下来,我们在 confusion_matrix 方法中添加 labels=label_names,并在数据框的构造函数中添加 index=label_names, columns=label_names。
label_names = list(set([] + predictions + true_labels))
label_names = ["None"] + sorted([a for a in label_names if a != "None"])
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
cfm_plot = sn.heatmap(df_cfm)
图片由作者提供
我们可以看到一个与标签相关的问题。Y 轴上的标签部分被截断。为了解决这个问题,我们可以使用 figsize 增加绘图画布。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm)
图片由作者提供
值
在下一步中,我们将为每个单元格显示一个值。这是一种方便的方法来观察每个类别对的精确错误数量。我们将使用heatmap方法的annot参数来显示这些值。它需要一个与我们的混淆矩阵相同维度的数据框。因此,我们可以再次传递相同的数据框,即df_cfm。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm)
图片由作者提供
我们可以看到确切的值,但这个图有两个问题。第一个是显示大数的方式很奇怪,另一个是 0 的数量很多,使得图表难以阅读。
为了解决显示数字的问题,我们将更改默认的字符串格式化代码,将其从.2g更改为空字符串,使用fmt参数。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm, fmt="")
图片由作者提供
为了隐藏 0,我们将复制数据框并将每个 0 替换为空字符串。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="")
cfm_plot.figure.savefig("data/confusion_matrix_v6.png")
图片由作者提供
尺度和颜色
热图的思路是使用颜色来直观地展示数值。在我们的例子中,我们可以看到只有一个不同的值,即 None 类中真正阳性的数量,而其他类别(狗和猪)的数值,而其他类别看起来相同。问题是我们的数值范围,从 0 到 1413,其中大多数数值接近 0。为了使数值更引人注目,我们可以将颜色的尺度从线性改为对数。这可以通过设置norm参数为LogNorm()来实现。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm())
图片由作者提供
使用颜色的对数尺度,图表看起来要好得多。移除空单元格后,分析非空单元格更容易,因为干扰更少。在某些情况下,可能很难跟随行和列。为了解决这个问题,我们可以使用linewidth和linecolor参数添加垂直和水平线。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
linewidths=0.5, linecolor="grey")
图片由作者提供
最后一步是选择首选的颜色调色板。Seaborn 有几个现成的调色板,在此处展示:seaborn.pydata.org/tutorial/color_palettes.html。要更改颜色调色板,将调色板的名称提供给cmap参数。这里是一个使用crest调色板的示例。
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
linewidths=0.5, linecolor="grey", cmap="crest")
图片由作者提供:调色板徽标
这里是一些其他调色板的示例:
图片由作者提供:调色板 viridis
图片由作者提供:调色板 magma
图片由作者提供:调色板 rocket_r
结论
在颜色和格式上玩耍可能看起来是浪费时间,因为数字才是最重要的。然而,合适的图表可能会显著提高我们数据的可读性和可访问性,尤其是在我们向一个对我们所熟悉的数据不那么熟悉的客户展示时。花一些额外的时间来弄清楚是否有更好的方式来展示原始数据和从分析中获得的见解是值得的。
参考文献
[2] scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
704

被折叠的 条评论
为什么被折叠?



