〖MMDetection〗解析文件:mmdet/models/losses/cross_entropy_loss.py

标题:深入剖析 MMDetection 中的 CrossEntropyLoss

在目标检测和分类任务中,损失函数的选择对于模型的性能至关重要。本文将详细解析 MMDetection 框架中的 CrossEntropyLossCrossEntropyCustomLoss,帮助读者深入理解其工作原理和应用场景。

一、引入相关模块和注册模块

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.registry import MODELS
from.accuracy import accuracy
from.utils import weight_reduce_loss

这段代码首先引入了必要的模块。warnings用于显示警告信息。torch是 PyTorch 库,用于深度学习计算。torch.nn包含神经网络模块。torch.nn.functional提供了一些常用的函数,如激活函数等。mmdet.registry.MODELS用于注册自定义模块。从当前目录下的accuracy模块中引入了accuracy函数,可能用于计算模型的准确率。从utils模块中引入了weight_reduce_loss函数,用于对损失进行加权和归约操作。

二、定义交叉熵损失相关函数

  1. cross_entropy函数:
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index
    # element-wise losses
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # average loss over non-ignored elements
    # pytorch's official cross_entropy average loss over non-ignored elements
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss

这个函数用于计算交叉熵损失。它接受预测值pred、标签label以及一些可选参数,如权重weight、归约方法reduction、平均因子avg_factor、类别权重class_weight和忽略索引ignore_index等。首先,使用F.cross_entropy计算元素级别的损失,然后根据参数进行加权和归约操作,最后返回计算得到的损失。

以下是对这段代码的详细解析:

函数功能

这个函数用于计算交叉熵损失(CrossEntropy loss)。它接受预测值、真实标签以及一些可选参数,返回计算得到的损失值。

参数解释

  • predtorch.Tensor类型,形状为(N, C),表示模型的预测结果,其中N是样本数量,C是类别数量。
  • labeltorch.Tensor类型,真实的学习标签,与预测结果对应。
  • weighttorch.Tensor类型,可选参数,表示每个样本的损失权重。
  • reduction:字符串类型,可选参数,用于指定损失的归约方法,可以是'none'(不进行归约)、'mean'(求平均)或'sum'(求和),默认为'mean'
  • avg_factor:整数类型,可选参数,用于平均损失的因子。
  • class_weight:列表类型,可选参数,表示每个类别的权重。
  • ignore_index:整数或None,可选参数,表示要忽略的标签索引,如果为None,则设置为默认值-100
  • avg_non_ignore:布尔类型,可选参数,表示是否只对非忽略的目标进行平均损失计算,默认为False

函数执行过程

  1. 处理默认参数和忽略索引:

    • 如果ignore_indexNone,则将其设置为默认值-100,这个值与 PyTorch 中F.cross_entropy函数的默认忽略索引相同。
  2. 计算元素级别的损失:

    • 使用F.cross_entropy函数计算元素级别的交叉熵损失,设置reduction='none'表示不进行归约操作,同时可以指定类别权重class_weight和忽略索引ignore_index
  3. 处理平均因子和非忽略目标平均:

    • 如果avg_factorNoneavg_non_ignoreTruereduction'mean',则计算平均因子avg_factor,它等于标签总数减去等于忽略索引的标签数量。
  4. 应用权重和进行归约:

    • 如果weight不为None,将其转换为浮点类型。
    • 使用weight_reduce_loss函数对损失进行加权和归约操作,传入计算得到的损失、权重、归约方法和平均因子。
  5. 返回损失值:

    • 返回最终计算得到的损失值。

总的来说,这个函数提供了一种灵活的方式来计算交叉熵损失,可以根据不同的需求进行参数配置,包括指定损失的归约方法、样本权重、类别权重以及是否对非忽略目标进行平均等。

  1. _expand_onehot_labels函数:
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(
        valid_mask & (labels < label_channels), as_tuple=False)

    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
                                               label_channels).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
        bin_label_weights *= valid_mask

    return bin_labels, bin_label_weights, valid_mask

这个辅助函数用于将标签扩展为 one-hot 格式,以匹配预测值的尺寸。它接受标签、标签权重、标签通道数和忽略索引作为参数,返回扩展后的标签、权重和有效掩码。

以下是对这段代码的详细解析:

函数功能

这个函数的目的是将标签扩展为 one-hot 编码格式,以匹配预测值的尺寸。它接受原始标签、标签权重、标签通道数和忽略索引作为参数,返回扩展后的 one-hot 标签、对应的权重以及有效掩码。

参数解释

  • labelstorch.Tensor类型,原始的标签张量。
  • label_weightstorch.Tensor类型或None,可选的标签权重张量。
  • label_channels:整数类型,表示标签的通道数(类别数)。
  • ignore_index:整数类型,表示要忽略的标签索引。

函数执行过程

  1. 创建全零的初始标签张量:

    • bin_labels = labels.new_full((labels.size(0), label_channels), 0)创建一个与原始标签张量的第一维大小相同(样本数量)且通道数为label_channels的全零张量。
  2. 计算有效掩码:

    • valid_mask = (labels >= 0) & (labels!= ignore_index)创建一个布尔掩码,用于标识哪些标签是有效的(既不小于 0 也不等于忽略索引)。
  3. 找到有效标签的索引:

    • inds = torch.nonzero(valid_mask & (labels < label_channels), as_tuple=False)找到同时满足有效掩码条件且标签值小于标签通道数的非零元素的索引。
  4. 设置 one-hot 标签:

    • 如果有效索引的数量大于 0,即存在有效的标签,那么将bin_labels在这些有效索引位置上设置为 1,即bin_labels[inds, labels[inds]] = 1
  5. 扩展有效掩码:

    • valid_mask = valid_mask.view(-1, 1).expand(labels.size(0), label_channels).float()将有效掩码从形状为(N,)扩展为形状为(N, label_channels)的浮点型张量。
  6. 处理标签权重:

    • 如果label_weightsNone,则将bin_label_weights设置为扩展后的有效掩码。
    • 如果label_weights不为None,则先将其形状从(N,)扩展为(N, 1)并重复label_channels次,然后与扩展后的有效掩码相乘,得到最终的bin_label_weights
  7. 返回结果:

    • 返回扩展后的 one-hot 标签bin_labels、对应的权重bin_label_weights以及有效掩码valid_mask

总的来说,这个函数在处理多分类问题时非常有用,特别是当需要将原始标签转换为与预测值形状匹配的 one-hot 编码格式时,可以方便地进行损失计算和其他操作。

  1. binary_cross_entropy函数:
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When the shape of pred is (N, 1), label will be expanded to
            one-hot format, and when the shape of pred is (N, ), label
            will not be expanded to one-hot format.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index

    if pred.dim() != label.dim():
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            # The inplace writing method will have a mismatched broadcast
            # shape error if the weight and valid_mask dimensions
            # are inconsistent such as (B,N,1) and (B,N,C).
            weight = weight * valid_mask
        else:
            weight = valid_mask

    # average loss over non-ignored elements
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    # weighted element-wise losses
    weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss

这个函数用于计算二值交叉熵损失。它根据预测值和标签的维度情况,进行标签的扩展和权重的处理,然后使用F.binary_cross_entropy_with_logits计算损失,并进行加权和归约操作,最后返回计算得到的损失。

以下是对这段代码的详细解析:

函数功能

这个函数用于计算二值交叉熵损失(binary CrossEntropy loss)。它接受预测值、真实标签以及一些可选参数,返回计算得到的损失值。

参数解释

  • predtorch.Tensor类型,预测值张量,形状可以是(N, 1)(N,),分别对应不同的处理方式。
  • labeltorch.Tensor类型,真实标签张量,形状为(N,)
  • weighttorch.Tensor类型,可选的样本损失权重。
  • reduction:字符串类型,可选参数,用于指定损失的归约方法,可以是'none''mean''sum'
  • avg_factor:整数类型,可选参数,用于平均损失的因子。
  • class_weight:列表类型,可选参数,表示正类的权重。
  • ignore_index:整数或None,可选参数,表示要忽略的标签索引,若为None则设为默认值-100
  • avg_non_ignore:布尔类型,可选参数,表示是否只对非忽略的目标进行平均损失计算。

函数执行过程

  1. 处理默认参数和忽略索引:

    • 如果ignore_indexNone,则将其设置为默认值-100,这个值与F.cross_entropy函数的默认忽略索引相同。
  2. 根据预测值和标签维度处理标签扩展和权重:

    • 如果预测值的维度与标签的维度不同:
      • 调用_expand_onehot_labels函数将标签扩展为 one-hot 格式,并同时处理权重和有效掩码。这个函数的作用是将标签转换为与预测值形状匹配的 one-hot 编码形式,以便进行后续的损失计算。
    • 如果预测值的维度与标签的维度相同:
      • 计算有效掩码valid_mask,通过判断标签是否大于等于 0 且不等于忽略索引来确定哪些标签是有效的,并将其转换为浮点型。
      • 如果权重存在,将权重与有效掩码相乘;如果权重不存在,则直接将有效掩码作为权重。
  3. 处理平均因子和非忽略目标平均:

    • 如果avg_factorNoneavg_non_ignoreTruereduction'mean',则将平均因子avg_factor设置为有效掩码的总和(即非忽略目标的数量)。
  4. 计算加权的元素级别的损失:

    • 将权重转换为浮点型。
    • 使用F.binary_cross_entropy_with_logits函数计算二值交叉熵损失,其中pred是未经激活的预测值,label.float()是浮点型的标签,pos_weight=class_weight表示正类的权重,设置reduction='none'表示不进行归约操作。
  5. 对加权损失进行归约:

    • 使用weight_reduce_loss函数对加权损失进行归约操作,传入计算得到的损失、权重、归约方法和平均因子。
  6. 返回损失值:

    • 返回最终计算得到的损失值。

总的来说,这个函数提供了一种计算二值交叉熵损失的方法,可以根据不同的参数配置进行灵活的计算,包括处理不同维度的预测值和标签、指定损失归约方法、考虑样本权重和非忽略目标平均等情况。

  1. mask_cross_entropy函数:
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask
            corresponding object. This will be used to select the mask in the
            of the class which the object belongs to when the mask prediction
            if not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        >>>                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]

这个函数用于计算掩码交叉熵损失。它接受预测值、目标值、标签以及一些参数,计算过程中会根据标签选择对应的预测值切片,然后使用F.binary_cross_entropy_with_logits计算损失,并进行归约操作,最后返回计算得到的损失。

以下是对这段代码的详细解析:

函数功能

这个函数用于计算掩码交叉熵损失(CrossEntropy loss for masks)。它接受预测值、目标值和标签作为输入,并返回计算得到的损失值。主要用于处理与掩码相关的任务,例如图像分割中的掩码预测。

参数解释

  • predtorch.Tensor类型,预测值张量,形状为(N, C, *),其中N表示样本数量,C是类别数量,*表示任意数量的额外维度。
  • targettorch.Tensor类型,预测的学习标签张量。
  • labeltorch.Tensor类型,指示掩码对应对象的类别标签。在掩码预测不是类别无关的情况下,用于选择属于该对象类别的掩码。
  • reduction:字符串类型,可选参数,用于指定损失的归约方法,这里只支持'mean',表示求平均。
  • avg_factor:整数类型,可选参数,用于平均损失的因子,这里默认为None,但在函数内部有特定处理。
  • class_weight:列表类型,可选参数,表示每个类别的权重。
  • ignore_index:默认为None,这里断言其为None,因为这个损失函数不支持忽略索引。

函数执行过程

  1. 断言和处理参数:

    • assert ignore_index is None, 'BCE loss does not support ignore_index',断言忽略索引必须为None,因为这个损失函数不支持忽略特定索引的情况。
    • assert reduction == 'mean' and avg_factor is None,断言归约方法必须为'mean'且平均因子必须为None,这是对参数的限制。
  2. 选择特定的预测切片:

    • num_rois = pred.size()[0],获取预测值的第一维大小,即样本数量。
    • inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device),创建一个从 0 到num_rois - 1的整数序列张量,用于索引。
    • pred_slice = pred[inds, label].squeeze(1),根据索引和标签选择预测值中的特定切片,并通过squeeze(1)去除可能存在的维度为 1 的维度。
  3. 计算损失并返回:

    • 使用F.binary_cross_entropy_with_logits函数计算二值交叉熵损失,输入为选择的预测切片pred_slice和目标值target,可以指定类别权重class_weight,设置归约方法为'mean'
    • [None]将损失值的形状从标量扩展为形状为(1,)的一维张量并返回。

总的来说,这个函数专门用于处理掩码相关任务中的交叉熵损失计算,具有特定的参数限制和计算逻辑。

三、定义 CrossEntropyLoss 类

@MODELS.register_module()
class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                of softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False.
            reduction (str, optional): . Defaults to 'mean'.
                Options are "none", "mean" and "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): The flag decides to whether the loss is
                only averaged over non-ignored targets. Default: False.
        """
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore
        if ((ignore_index is not None) and not self.avg_non_ignore
                and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

这个类继承自nn.Module,是交叉熵损失的主要实现类。构造函数接受一些参数,用于配置损失函数的行为,如是否使用 sigmoid 激活、是否使用掩码、归约方法、类别权重、忽略索引、损失权重和是否仅对非忽略目标求平均等。在构造函数中,根据参数确定使用哪种具体的损失计算函数,并进行一些警告和初始化操作。

以下是对这段代码的详细解析:

类功能及装饰器说明

CrossEntropyLoss是一个自定义的交叉熵损失类,它继承自nn.Module,表示这是一个可在 PyTorch 中作为神经网络模块使用的类。@MODELS.register_module()是一个装饰器,用于将这个类注册到一个特定的模块注册表中,以便在其他地方可以方便地通过名称来调用和实例化这个类。

构造函数参数解释

  • use_sigmoid:布尔类型,可选参数,默认为False。表示预测是否使用 sigmoid 激活函数而不是 softmax。如果为True,则在计算损失时会有不同的处理。
  • use_mask:布尔类型,可选参数,默认为False。表示是否使用掩码交叉熵损失。如果为True,会采用特定的损失计算方式。
  • reduction:字符串类型,可选参数,默认为'mean'。表示损失的归约方法,可以是'none'(不进行归约)、'mean'(求平均)或'sum'(求和)。
  • class_weight:列表类型,可选参数,默认为None。表示每个类别的权重。
  • ignore_index:整数或None,可选参数,默认为None。表示要忽略的标签索引。
  • loss_weight:浮点数类型,可选参数,默认为1.0。表示损失的权重。
  • avg_non_ignore:布尔类型,默认为False。表示损失是否只在非忽略的目标上进行平均。

构造函数执行过程

  1. 调用父类构造函数:

    • super(CrossEntropyLoss, self).__init__()调用父类nn.Module的构造函数,进行必要的初始化操作。
  2. 断言和参数设置:

    • assert (use_sigmoid is False) or (use_mask is False)确保use_sigmoiduse_mask不能同时为True,因为这两种方式是互斥的选择。
    • 分别将传入的参数赋值给类的属性,如self.use_sigmoidself.use_maskself.reduction等。
  3. 警告提示:

    • 如果ignore_index不为Noneavg_non_ignoreFalsereduction'mean',则发出一个警告,提示用户如果想要忽略特定标签并在非忽略的标签上平均损失,可以将avg_non_ignore设置为True,这与 PyTorch 官方的交叉熵损失计算方式一致。
  4. 根据参数选择损失计算准则:

    • 如果self.use_sigmoidTrue,则将self.cls_criterion设置为binary_cross_entropy函数,表示使用二值交叉熵损失计算方式。
    • 如果self.use_maskTrue,则将self.cls_criterion设置为mask_cross_entropy函数,表示使用掩码交叉熵损失计算方式。
    • 如果都不是,则将self.cls_criterion设置为cross_entropy函数,表示使用普通的交叉熵损失计算方式。

总的来说,这个类提供了一种灵活的方式来配置和计算交叉熵损失,可以根据不同的需求选择不同的损失计算方式和参数设置。

    def extra_repr(self):
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

这个方法用于提供额外的字符串表示,在打印对象时可以显示更多信息。

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value. Default: None.
        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if ignore_index is None:
            ignore_index = self.ignore_index

        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)
        else:
            class_weight = None
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return loss_cls

这个方法是模型的前向传播方法。它接受预测值、标签以及一些参数,根据参数确定损失的归约方法和忽略索引,然后调用相应的损失计算函数,并乘以损失权重,最后返回计算得到的损失。

以下是对这段代码的详细解析:

函数功能

这是一个类方法,用于计算交叉熵损失的前向传播过程。它接受预测值、真实标签以及一些可选参数,通过调用特定的损失计算准则(cls_criterion)来计算损失,并根据配置进行加权和处理,最后返回计算得到的损失值。

参数解释

  • cls_scoretorch.Tensor类型,模型的预测结果。
  • labeltorch.Tensor类型,预测结果的真实学习标签。
  • weighttorch.Tensor类型,可选参数,样本级别的损失权重。
  • avg_factor:整数类型,可选参数,用于平均损失的因子。
  • reduction_override:字符串类型,可选参数,用于覆盖默认的损失归约方法,可选值为None'none''mean''sum'
  • ignore_index:整数或None,可选参数,要忽略的标签索引,如果不为None,将覆盖默认的忽略索引。
  • **kwargs:额外的关键字参数,用于传递给损失计算准则。

函数执行过程

  1. 参数检查和设置:

    • assert reduction_override in (None, 'none', 'mean', 'sum'),确保reduction_override参数的值是合法的。
    • 如果reduction_override不为None,则使用它作为损失归约方法;否则,使用类的属性self.reduction作为归约方法。
    • 如果ignore_indexNone,则将其设置为类的属性self.ignore_index
  2. 处理类别权重:

    • 如果类有定义的类别权重self.class_weight,则创建一个新的张量class_weight,其值为类别权重,并放置在与cls_score相同的设备上。
    • 如果没有定义类别权重,则将class_weight设置为None
  3. 计算损失:

    • 使用self.cls_criterion(根据类的配置确定的具体损失计算准则)来计算损失。传入预测值cls_score、标签label、权重weight、类别权重class_weight、归约方法reduction、平均因子avg_factor、忽略索引ignore_index以及avg_non_ignore属性和额外的关键字参数**kwargs
    • 将计算得到的损失乘以类的属性self.loss_weight,得到最终的损失值loss_cls
  4. 返回损失值:

    • 返回计算得到的损失值loss_cls

总的来说,这个方法在交叉熵损失类中起到了核心的前向计算作用,通过灵活的参数配置和调用特定的损失计算准则,能够适应不同的任务需求。

四、定义 CrossEntropyCustomLoss 类

@MODELS.register_module()
class CrossEntropyCustomLoss(CrossEntropyLoss):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 num_classes=-1,
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyCustomLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                of softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False.
            reduction (str, optional): . Defaults to 'mean'.
                Options are "none", "mean" and "sum".
            num_classes (int): Number of classes to classify.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): The flag decides to whether the loss is
                only averaged over non-ignored targets. Default: False.
        """
        super(CrossEntropyCustomLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore
        if ((ignore_index is not None) and not self.avg_non_ignore
                and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

        self.num_classes = num_classes

        assert self.num_classes != -1

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classsifier
        self.custom_accuracy = True

这个类继承自CrossEntropyLoss,是一个自定义的交叉熵损失类。构造函数除了接受与CrossEntropyLoss相同的参数外,还接受一个num_classes参数,表示分类的类别数。在构造函数中,除了进行与父类相同的初始化操作外,还设置了一些自定义的属性,如custom_cls_channelscustom_activationcustom_accuracy等。

以下是对这段代码的详细解析:

类功能及装饰器说明

CrossEntropyCustomLoss是一个自定义的交叉熵损失类,它继承自CrossEntropyLoss类。和前面的CrossEntropyLoss一样,它也被@MODELS.register_module()装饰器注册到特定的模块注册表中,以便在其他地方可以方便地调用和实例化。

构造函数参数解释

  • use_sigmoid:布尔类型,可选参数,默认为False。表示预测是否使用 sigmoid 激活函数。
  • use_mask:布尔类型,可选参数,默认为False。表示是否使用掩码交叉熵损失。
  • reduction:字符串类型,可选参数,默认为'mean'。表示损失的归约方法。
  • num_classes:整数类型,表示要分类的类别数量。
  • class_weight:列表类型,可选参数,默认为None。表示每个类别的权重。
  • ignore_index:整数或None,可选参数,默认为None。表示要忽略的标签索引。
  • loss_weight:浮点数类型,可选参数,默认为1.0。表示损失的权重。
  • avg_non_ignore:布尔类型,默认为False。表示损失是否只在非忽略的目标上进行平均。

构造函数执行过程

  1. 调用父类构造函数:

    • super(CrossEntropyCustomLoss, self).__init__()调用父类CrossEntropyLoss的构造函数,进行一些基本的初始化操作。
  2. 断言和参数设置:

    • CrossEntropyLoss类似,确保use_sigmoiduse_mask不能同时为True
    • 将传入的参数赋值给类的属性,如self.use_sigmoidself.use_maskself.reduction等。
  3. 警告提示:

    • 如果ignore_index不为Noneavg_non_ignoreFalsereduction'mean',发出警告提示用户可以设置avg_non_ignoreTrue以实现与 PyTorch 官方交叉熵损失相同的行为。
  4. 根据参数选择损失计算准则:

    • CrossEntropyLoss类似,根据use_sigmoiduse_mask的值选择合适的损失计算函数赋值给self.cls_criterion
  5. 额外的属性设置:

    • self.num_classes = num_classes,将传入的类别数量赋值给类的属性。
    • assert self.num_classes!= -1,确保传入的类别数量不是 -1。
    • 设置三个自定义属性为True,分别表示分类器的自定义输出通道、自定义激活函数和自定义准确率计算。

总的来说,这个类在继承CrossEntropyLoss的基础上,增加了对特定类别数量的要求,并设置了一些自定义的属性,用于在特定的场景下进行更灵活的交叉熵损失计算和处理。

    def get_cls_channels(self, num_classes):
        assert num_classes == self.num_classes
        if not self.use_sigmoid:
            return num_classes + 1
        else:
            return num_classes

这个方法用于获取分类器的输出通道数,根据是否使用 sigmoid 激活进行不同的计算。

    def get_activation(self, cls_score):

        fine_cls_score = cls_score[:, :self.num_classes]

        if not self.use_sigmoid:
            bg_score = cls_score[:, [-1]]
            new_score = torch.cat([fine_cls_score, bg_score], dim=-1)
            scores = F.softmax(new_score, dim=-1)
        else:
            score_classes = fine_cls_score.sigmoid()
            score_neg = 1 - score_classes.sum(dim=1, keepdim=True)
            score_neg = score_neg.clamp(min=0, max=1)
            scores = torch.cat([score_classes, score_neg], dim=1)

        return scores

这个方法用于获取分类器的激活函数输出,根据是否使用 sigmoid 激活进行不同的处理。

以下是对这段代码的详细解析:

方法功能

这个方法用于获取分类器的激活输出。它根据不同的配置(是否使用 sigmoid)对输入的分类得分进行处理,得到最终的激活后的分数。

参数解释

  • cls_scoretorch.Tensor类型,分类器的输出得分张量。

方法执行过程

  1. 提取精细分类得分:

    • fine_cls_score = cls_score[:, :self.num_classes],从输入的分类得分中提取出前self.num_classes个类别对应的得分,即精细分类得分。
  2. 根据是否使用 sigmoid 进行不同处理:

    • 如果不使用 sigmoid:

      • bg_score = cls_score[:, [-1]],获取背景类别的得分,这里使用了切片[-1]来获取最后一个维度上的最后一个元素。
      • new_score = torch.cat([fine_cls_score, bg_score], dim=-1),将精细分类得分和背景得分在最后一个维度上进行拼接。
      • scores = F.softmax(new_score, dim=-1),对拼接后的得分进行 softmax 激活,得到最终的激活后的分数。
    • 如果使用 sigmoid:

      • score_classes = fine_cls_score.sigmoid(),对精细分类得分进行 sigmoid 激活,得到各个类别的激活得分。
      • score_neg = 1 - score_classes.sum(dim=1, keepdim=True),计算除了正类之外的其他类别的总和的补集,即负类的得分。这里使用了sum函数在维度 1 上进行求和,并通过keepdim=True保持维度。
      • score_neg = score_neg.clamp(min=0, max=1),将负类得分限制在 0 到 1 的范围内。
      • scores = torch.cat([score_classes, score_neg], dim=1),将正类得分和负类得分在维度 1 上进行拼接,得到最终的激活后的分数。
  3. 返回激活后的分数:

    • 返回计算得到的激活后的分数scores

总的来说,这个方法提供了一种根据不同的配置来处理分类器输出得分并得到激活后分数的方式,可用于后续的计算和评估。

    def get_accuracy(self, cls_score, labels):

        fine_cls_score = cls_score[:, :self.num_classes]

        pos_inds = labels < self.num_classes
        acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_classes'] = acc_classes
        return acc

这个方法用于获取分类器的准确率,根据标签选择有效的预测值进行准确率计算。
以下是对这段代码的详细解析:

方法功能

这个方法用于计算分类器的准确率。它根据输入的分类得分和真实标签,提取出正类的分类得分,并计算在这些正类上的准确率,最后以字典的形式返回准确率结果。

参数解释

  • cls_scoretorch.Tensor类型,分类器的输出得分张量。
  • labelstorch.Tensor类型,真实的标签张量。

方法执行过程

  1. 提取精细分类得分:

    • fine_cls_score = cls_score[:, :self.num_classes],从输入的分类得分中提取出前self.num_classes个类别对应的得分,即精细分类得分。
  2. 确定正类索引:

    • pos_inds = labels < self.num_classes,创建一个布尔张量,其中值为True的位置表示对应的标签小于self.num_classes,即属于正类。
  3. 计算正类准确率:

    • acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds]),调用accuracy函数(假设这是一个自定义的准确率计算函数),传入正类的分类得分(通过索引pos_indsfine_cls_score中选择)和正类的标签(同样通过索引pos_indslabels中选择),计算正类的准确率。
  4. 构建准确率字典并返回:

    • acc = dict()创建一个空字典。
    • acc['acc_classes'] = acc_classes将正类准确率添加到字典中,键为'acc_classes'
    • return acc返回包含正类准确率的字典。

综上所述,CrossEntropyLossCrossEntropyCustomLoss提供了多种交叉熵损失的计算方式,可以根据不同的任务需求进行配置和使用。通过注册为 MMDetection 框架中的模块,可以方便地在模型中调用和配置这些损失函数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值