Pytorch+Tensorboard混淆矩阵可视化

本文介绍如何在TensorBoard中可视化混淆矩阵,通过Python代码实现混淆矩阵的计算与绘制,并将其添加到TensorBoard中以方便地观察分类模型的表现。

引言

混淆矩阵是分类任务常用的一种评估方法。对角线元素表示预测标签等于真实标签的点数,而非对角线元素则是分类器未正确标记的点的数量。 混淆矩阵的对角线值越高越好,表明有许多正确的预测。1

尤其是在类别数量不平衡的情况下,相比accuracy,混淆矩阵(confusion matrix)对哪个类被错误分类具有更直观的解释

在平时做简单的数据实验时,可以仅用from sklearn.metrics import plot_confusion_matrix或者seaborn对混淆矩阵进行可视化。但是在深度学习训练模型的过程中,在tensorboard中可视化混淆矩阵会更方便结果记录和对照。

混淆矩阵

在tensorboard中的可视化效果:
在这里插入图片描述

代码实现

代码参考facebook的SlowFast工程2

引用库

import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix

计算混淆矩阵

从pytorch模型输出的预测结果preds、真值labels,计算混淆矩阵。

def get_confusion_matrix(preds, labels, num_classes, normalize="true"):
    """
    Calculate confusion matrix on the provided preds and labels.
    Args:
        preds (tensor or lists of tensors): predictions. Each tensor is in
            in the shape of (n_batch, num_classes). Tensor(s) must be on CPU.
        labels (tensor or lists of tensors): corresponding labels. Each tensor is
            in the shape of either (n_batch,) or (n_batch, num_classes).
        num_classes (int): number of classes. Tensor(s) must be on CPU.
        normalize (Optional[str]) : {‘true’, ‘pred’, ‘all’}, default="true"
            Normalizes confusion matrix over the true (rows), predicted (columns)
            conditions or all the population. If None, confusion matrix
            will not be normalized.
    Returns:
        cmtx (ndarray): confusion matrix of size (num_classes x num_classes)
    """
    if isinstance(preds, list):
        preds = torch.cat(preds, dim=0)
    if isinstance(labels, list):
        labels = torch.cat(labels, dim=0)
    # If labels are one-hot encoded, get their indices.
    if labels.ndim == preds.ndim:
        labels = torch.argmax(labels, dim=-1)
    # Get the predicted class indices for examples.
    preds = torch.flatten(torch.argmax(preds, dim=-1))
    labels = torch.
评论 7
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值