Introduction
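The confusion matrix is a common way to evaluate classification tasks. The diagonal elements give the number of points for which the predicted label equals the true label, while the off-diagonal elements count the points the classifier mislabeled. The higher the diagonal values, the better, since they indicate many correct predictions.[1]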
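To make the definition concrete, here is a minimal NumPy sketch. The helper confusion_matrix_np and the toy labels are hypothetical, for illustration only. It follows scikit-learn's convention that rows are true labels and columns are predicted labels, so the diagonal holds the correct predictions.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes):
    """Hypothetical helper: count (true, predicted) pairs.

    Rows index the true label, columns index the predicted label.
    """
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at is unbuffered, so repeated (i, j) pairs are all counted.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
cm = confusion_matrix_np(y_true, y_pred, 3)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
correct = np.trace(cm)  # sum of the diagonal = number of correct predictions
print(correct, "of", cm.sum())  # 4 of 6
```

Each off-diagonal cell says exactly which mistake was made: cm[2, 0] == 1 means one sample of class 2 was predicted as class 0.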
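This is especially useful when the classes are imbalanced: unlike accuracy, which collapses performance into a single number, a confusion matrix shows directly which classes are misclassified and which classes they are confused with.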
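A toy illustration of the imbalance case, using made-up data: a classifier that always predicts the majority class reaches 95% accuracy, yet the confusion matrix shows that the minority class is never predicted. Dividing each row by its sum (what normalize="true" does in scikit-learn) turns the diagonal into per-class recall.

```python
import numpy as np

# Made-up imbalanced data: 95 samples of class 0, 5 samples of class 1.
y_true = np.array([0] * 95 + [1] * 5)
# A degenerate classifier that always predicts the majority class.
y_pred = np.zeros_like(y_true)

accuracy = (y_true == y_pred).mean()
print(accuracy)  # 0.95, looks good

cm = np.zeros((2, 2), dtype=np.int64)
np.add.at(cm, (y_true, y_pred), 1)
print(cm)
# [[95  0]
#  [ 5  0]]

# Row-normalize: each row now sums to 1, and the diagonal is per-class recall.
recall_matrix = cm / cm.sum(axis=1, keepdims=True)
per_class_recall = np.diag(recall_matrix)
print(per_class_recall)  # class 1 recall is 0: it is never detected
```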
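For quick, simple experiments, it is enough to visualize a confusion matrix with from sklearn.metrics import plot_confusion_matrix or with seaborn (note that plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in 1.2; newer versions provide ConfusionMatrixDisplay instead). When training deep learning models, however, visualizing the confusion matrix in TensorBoard makes it easier to keep a record of the results and compare them across runs.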
Confusion matrix
How the confusion matrix looks in TensorBoard:
[Figure: confusion matrix visualized in TensorBoard]
Implementation
The code is adapted from Facebook's SlowFast project[2]:
Imports
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix
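Computing the confusion matrix
Compute the confusion matrix from the predictions preds output by the PyTorch model and the ground-truth labels labels.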
def get_confusion_matrix(preds, labels, num_classes, normalize="true"):
    """
    Calculate confusion matrix on the provided preds and labels.
    Args:
        preds (tensor or lists of tensors): predictions. Each tensor is
            in the shape of (n_batch, num_classes). Tensor(s) must be on CPU.
        labels (tensor or lists of tensors): corresponding labels. Each tensor is
            in the shape of either (n_batch,) or (n_batch, num_classes).
            Tensor(s) must be on CPU.
        num_classes (int): number of classes.
        normalize (Optional[str]): {'true', 'pred', 'all'}, default="true"
            Normalizes confusion matrix over the true (rows), predicted (columns)
            conditions or all the population. If None, confusion matrix
            will not be normalized.
    Returns:
        cmtx (ndarray): confusion matrix of size (num_classes x num_classes)
    """
    if isinstance(preds, list):
        preds = torch.cat(preds, dim=0)
    if isinstance(labels, list):
        labels = torch.cat(labels, dim=0)
    # If labels are one-hot encoded, get their indices.
    if labels.ndim == preds.ndim:
        labels = torch.argmax(labels, dim=-1)
    # Get the predicted class indices for examples.
    preds = torch.flatten(torch.argmax(preds, dim=-1))
    labels = torch.flatten(labels)
    # Fix the label set to 0..num_classes-1 so the matrix is always
    # num_classes x num_classes, even if some classes never occur.
    cmtx = confusion_matrix(
        labels, preds, labels=list(range(num_classes)), normalize=normalize
    )
    return cmtx
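The article breaks off before the plotting and logging steps. What follows is a sketch under assumptions, not the author's or SlowFast's exact code: draw_confusion_matrix and log_confusion_matrix are hypothetical helper names, matplotlib draws the matrix (itertools walks the cells), and the figure is logged with SummaryWriter.add_figure from torch.utils.tensorboard. Forcing the Agg backend is an assumption suited to headless training servers.

```python
import itertools

import matplotlib

matplotlib.use("Agg")  # assumption: headless server, render off-screen
import matplotlib.pyplot as plt
import numpy as np


def draw_confusion_matrix(cmtx, class_names, figsize=(6, 6)):
    """Hypothetical helper: render an (N x N) ndarray as a matplotlib Figure."""
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cmtx, interpolation="nearest", cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    # Annotate every cell; white text on dark cells keeps it readable.
    threshold = cmtx.max() / 2.0
    for i, j in itertools.product(range(cmtx.shape[0]), range(cmtx.shape[1])):
        color = "white" if cmtx[i, j] > threshold else "black"
        ax.text(j, i, f"{cmtx[i, j]:.2f}", ha="center", va="center", color=color)
    fig.tight_layout()
    return fig


def log_confusion_matrix(writer, cmtx, class_names, global_step, tag="Confusion Matrix"):
    """Hypothetical helper: log the figure through a SummaryWriter-like object."""
    fig = draw_confusion_matrix(cmtx, class_names)
    # SummaryWriter.add_figure renders the figure to an image and,
    # by default (close=True), closes it afterwards.
    writer.add_figure(tag, fig, global_step=global_step)
```

In a training loop, one would create writer = SummaryWriter("runs/exp") once ("runs/exp" is a placeholder log directory) and call log_confusion_matrix(writer, cmtx, class_names, global_step=epoch) after each evaluation, with cmtx from get_confusion_matrix. Each logged step then appears in TensorBoard's Images tab, so the matrices of different epochs can be compared with the step slider.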
