标题:深入剖析 MMDetection 中的 CrossEntropyLoss
在目标检测和分类任务中,损失函数的选择对于模型的性能至关重要。本文将详细解析 MMDetection 框架中的 CrossEntropyLoss
和 CrossEntropyCustomLoss
,帮助读者深入理解其工作原理和应用场景。
一、引入相关模块和注册模块
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
from.accuracy import accuracy
from.utils import weight_reduce_loss
这段代码首先引入了必要的模块。warnings
用于显示警告信息。torch
是 PyTorch 库,用于深度学习计算。torch.nn
包含神经网络模块。torch.nn.functional
提供了一些常用的函数,如激活函数等。mmdet.registry.MODELS
用于注册自定义模块。从当前目录下的accuracy
模块中引入了accuracy
函数,可能用于计算模型的准确率。从utils
模块中引入了weight_reduce_loss
函数,用于对损失进行加权和归约操作。
二、定义交叉熵损失相关函数
cross_entropy
函数:
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
# element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
这个函数用于计算交叉熵损失。它接受预测值pred
、标签label
以及一些可选参数,如权重weight
、归约方法reduction
、平均因子avg_factor
、类别权重class_weight
和忽略索引ignore_index
等。首先,使用F.cross_entropy
计算元素级别的损失,然后根据参数进行加权和归约操作,最后返回计算得到的损失。
以下是对这段代码的详细解析:
函数功能:
这个函数用于计算交叉熵损失(CrossEntropy loss)。它接受预测值、真实标签以及一些可选参数,返回计算得到的损失值。
参数解释:
pred
:torch.Tensor
类型,形状为(N, C)
,表示模型的预测结果,其中N
是样本数量,C
是类别数量。label
:torch.Tensor
类型,真实的学习标签,与预测结果对应。weight
:torch.Tensor
类型,可选参数,表示每个样本的损失权重。reduction
:字符串类型,可选参数,用于指定损失的归约方法,可以是'none'
(不进行归约)、'mean'
(求平均)或'sum'
(求和),默认为'mean'
。avg_factor
:整数类型,可选参数,用于平均损失的因子。class_weight
:列表类型,可选参数,表示每个类别的权重。ignore_index
:整数或None
,可选参数,表示要忽略的标签索引,如果为None
,则设置为默认值-100
。avg_non_ignore
:布尔类型,可选参数,表示是否只对非忽略的目标进行平均损失计算,默认为False
。
函数执行过程:
-
处理默认参数和忽略索引:
- 如果
ignore_index
为None
,则将其设置为默认值-100
,这个值与 PyTorch 中F.cross_entropy
函数的默认忽略索引相同。
- 如果
-
计算元素级别的损失:
- 使用
F.cross_entropy
函数计算元素级别的交叉熵损失,设置reduction='none'
表示不进行归约操作,同时可以指定类别权重class_weight
和忽略索引ignore_index
。
- 使用
-
处理平均因子和非忽略目标平均:
- 如果
avg_factor
为None
且avg_non_ignore
为True
且reduction
为'mean'
,则计算平均因子avg_factor
,它等于标签总数减去等于忽略索引的标签数量。
- 如果
-
应用权重和进行归约:
- 如果
weight
不为None
,将其转换为浮点类型。 - 使用
weight_reduce_loss
函数对损失进行加权和归约操作,传入计算得到的损失、权重、归约方法和平均因子。
- 如果
-
返回损失值:
- 返回最终计算得到的损失值。
总的来说,这个函数提供了一种灵活的方式来计算交叉熵损失,可以根据不同的需求进行参数配置,包括指定损失的归约方法、样本权重、类别权重以及是否对非忽略目标进行平均等。
_expand_onehot_labels
函数:
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(
valid_mask & (labels < label_channels), as_tuple=False)
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
label_channels).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights, valid_mask
这个辅助函数用于将标签扩展为 one-hot 格式,以匹配预测值的尺寸。它接受标签、标签权重、标签通道数和忽略索引作为参数,返回扩展后的标签、权重和有效掩码。
以下是对这段代码的详细解析:
函数功能:
这个函数的目的是将标签扩展为 one-hot 编码格式,以匹配预测值的尺寸。它接受原始标签、标签权重、标签通道数和忽略索引作为参数,返回扩展后的 one-hot 标签、对应的权重以及有效掩码。
参数解释:
labels
:torch.Tensor
类型,原始的标签张量。label_weights
:torch.Tensor
类型或None
,可选的标签权重张量。label_channels
:整数类型,表示标签的通道数(类别数)。ignore_index
:整数类型,表示要忽略的标签索引。
函数执行过程:
-
创建全零的初始标签张量:
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
创建一个与原始标签张量的第一维大小相同(样本数量)且通道数为label_channels
的全零张量。
-
计算有效掩码:
valid_mask = (labels >= 0) & (labels!= ignore_index)
创建一个布尔掩码,用于标识哪些标签是有效的(既不小于 0 也不等于忽略索引)。
-
找到有效标签的索引:
inds = torch.nonzero(valid_mask & (labels < label_channels), as_tuple=False)
找到同时满足有效掩码条件且标签值小于标签通道数的非零元素的索引。
-
设置 one-hot 标签:
- 如果有效索引的数量大于 0,即存在有效的标签,那么将
bin_labels
在这些有效索引位置上设置为 1,即bin_labels[inds, labels[inds]] = 1
。
- 如果有效索引的数量大于 0,即存在有效的标签,那么将
-
扩展有效掩码:
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0), label_channels).float()
将有效掩码从形状为(N,)
扩展为形状为(N, label_channels)
的浮点型张量。
-
处理标签权重:
- 如果
label_weights
为None
,则将bin_label_weights
设置为扩展后的有效掩码。 - 如果
label_weights
不为None
,则先将其形状从(N,)
扩展为(N, 1)
并重复label_channels
次,然后与扩展后的有效掩码相乘,得到最终的bin_label_weights
。
- 如果
-
返回结果:
- 返回扩展后的 one-hot 标签
bin_labels
、对应的权重bin_label_weights
以及有效掩码valid_mask
。
- 返回扩展后的 one-hot 标签
总的来说,这个函数在处理多分类问题时非常有用,特别是当需要将原始标签转换为与预测值形状匹配的 one-hot 编码格式时,可以方便地进行损失计算和其他操作。
binary_cross_entropy
函数:
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
When the shape of pred is (N, 1), label will be expanded to
one-hot format, and when the shape of pred is (N, ), label
will not be expanded to one-hot format.
label (torch.Tensor): The learning label of the prediction,
with shape (N, ).
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
if pred.dim() != label.dim():
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
# The inplace writing method will have a mismatched broadcast
# shape error if the weight and valid_mask dimensions
# are inconsistent such as (B,N,1) and (B,N,C).
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()
# weighted element-wise losses
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
这个函数用于计算二值交叉熵损失。它根据预测值和标签的维度情况,进行标签的扩展和权重的处理,然后使用F.binary_cross_entropy_with_logits
计算损失,并进行加权和归约操作,最后返回计算得到的损失。
以下是对这段代码的详细解析:
函数功能:
这个函数用于计算二值交叉熵损失(binary CrossEntropy loss)。它接受预测值、真实标签以及一些可选参数,返回计算得到的损失值。
参数解释:
pred
:torch.Tensor
类型,预测值张量,形状可以是(N, 1)
或(N,)
,分别对应不同的处理方式。label
:torch.Tensor
类型,真实标签张量,形状为(N,)
。weight
:torch.Tensor
类型,可选的样本损失权重。reduction
:字符串类型,可选参数,用于指定损失的归约方法,可以是'none'
、'mean'
或'sum'
。avg_factor
:整数类型,可选参数,用于平均损失的因子。class_weight
:列表类型,可选参数,表示正类的权重。ignore_index
:整数或None
,可选参数,表示要忽略的标签索引,若为None
则设为默认值-100
。avg_non_ignore
:布尔类型,可选参数,表示是否只对非忽略的目标进行平均损失计算。
函数执行过程:
-
处理默认参数和忽略索引:
- 如果
ignore_index
为None
,则将其设置为默认值-100
,这个值与F.cross_entropy
函数的默认忽略索引相同。
- 如果
-
根据预测值和标签维度处理标签扩展和权重:
- 如果预测值的维度与标签的维度不同:
- 调用
_expand_onehot_labels
函数将标签扩展为 one-hot 格式,并同时处理权重和有效掩码。这个函数的作用是将标签转换为与预测值形状匹配的 one-hot 编码形式,以便进行后续的损失计算。
- 调用
- 如果预测值的维度与标签的维度相同:
- 计算有效掩码
valid_mask
,通过判断标签是否大于等于 0 且不等于忽略索引来确定哪些标签是有效的,并将其转换为浮点型。 - 如果权重存在,将权重与有效掩码相乘;如果权重不存在,则直接将有效掩码作为权重。
- 计算有效掩码
- 如果预测值的维度与标签的维度不同:
-
处理平均因子和非忽略目标平均:
- 如果
avg_factor
为None
且avg_non_ignore
为True
且reduction
为'mean'
,则将平均因子avg_factor
设置为有效掩码的总和(即非忽略目标的数量)。
- 如果
-
计算加权的元素级别的损失:
- 将权重转换为浮点型。
- 使用
F.binary_cross_entropy_with_logits
函数计算二值交叉熵损失,其中pred
是未经激活的预测值,label.float()
是浮点型的标签,pos_weight=class_weight
表示正类的权重,设置reduction='none'
表示不进行归约操作。
-
对加权损失进行归约:
- 使用
weight_reduce_loss
函数对加权损失进行归约操作,传入计算得到的损失、权重、归约方法和平均因子。
- 使用
-
返回损失值:
- 返回最终计算得到的损失值。
总的来说,这个函数提供了一种计算二值交叉熵损失的方法,可以根据不同的参数配置进行灵活的计算,包括处理不同维度的预测值和标签、指定损失归约方法、考虑样本权重和非忽略目标平均等情况。
mask_cross_entropy
函数:
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C, *), C is the
number of classes. The trailing * indicates arbitrary shape.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the mask
corresponding object. This will be used to select the mask in the
of the class which the object belongs to when the mask prediction
if not class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
Example:
>>> N, C = 3, 11
>>> H, W = 2, 2
>>> pred = torch.randn(N, C, H, W) * 1000
>>> target = torch.rand(N, H, W)
>>> label = torch.randint(0, C, size=(N,))
>>> reduction = 'mean'
>>> avg_factor = None
>>> class_weights = None
>>> loss = mask_cross_entropy(pred, target, label, reduction,
>>> avg_factor, class_weights)
>>> assert loss.shape == (1,)
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
这个函数用于计算掩码交叉熵损失。它接受预测值、目标值、标签以及一些参数,计算过程中会根据标签选择对应的预测值切片,然后使用F.binary_cross_entropy_with_logits
计算损失,并进行归约操作,最后返回计算得到的损失。
以下是对这段代码的详细解析:
函数功能:
这个函数用于计算掩码交叉熵损失(CrossEntropy loss for masks)。它接受预测值、目标值和标签作为输入,并返回计算得到的损失值。主要用于处理与掩码相关的任务,例如图像分割中的掩码预测。
参数解释:
pred
:torch.Tensor
类型,预测值张量,形状为(N, C, *)
,其中N
表示样本数量,C
是类别数量,*
表示任意数量的额外维度。target
:torch.Tensor
类型,预测的学习标签张量。label
:torch.Tensor
类型,指示掩码对应对象的类别标签。在掩码预测不是类别无关的情况下,用于选择属于该对象类别的掩码。reduction
:字符串类型,可选参数,用于指定损失的归约方法,这里只支持'mean'
,表示求平均。avg_factor
:整数类型,可选参数,用于平均损失的因子,这里默认为None
,但在函数内部有特定处理。class_weight
:列表类型,可选参数,表示每个类别的权重。ignore_index
:默认为None
,这里断言其为None
,因为这个损失函数不支持忽略索引。
函数执行过程:
-
断言和处理参数:
assert ignore_index is None, 'BCE loss does not support ignore_index'
,断言忽略索引必须为None
,因为这个损失函数不支持忽略特定索引的情况。assert reduction == 'mean' and avg_factor is None
,断言归约方法必须为'mean'
且平均因子必须为None
,这是对参数的限制。
-
选择特定的预测切片:
num_rois = pred.size()[0]
,获取预测值的第一维大小,即样本数量。inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
,创建一个从 0 到num_rois - 1
的整数序列张量,用于索引。pred_slice = pred[inds, label].squeeze(1)
,根据索引和标签选择预测值中的特定切片,并通过squeeze(1)
去除可能存在的维度为 1 的维度。
-
计算损失并返回:
- 使用
F.binary_cross_entropy_with_logits
函数计算二值交叉熵损失,输入为选择的预测切片pred_slice
和目标值target
,可以指定类别权重class_weight
,设置归约方法为'mean'
。 [None]
将损失值的形状从标量扩展为形状为(1,)
的一维张量并返回。
- 使用
总的来说,这个函数专门用于处理掩码相关任务中的交叉熵损失计算,具有特定的参数限制和计算逻辑。
三、定义 CrossEntropyLoss 类
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
这个类继承自nn.Module
,是交叉熵损失的主要实现类。构造函数接受一些参数,用于配置损失函数的行为,如是否使用 sigmoid 激活、是否使用掩码、归约方法、类别权重、忽略索引、损失权重和是否仅对非忽略目标求平均等。在构造函数中,根据参数确定使用哪种具体的损失计算函数,并进行一些警告和初始化操作。
以下是对这段代码的详细解析:
类功能及装饰器说明:
CrossEntropyLoss
是一个自定义的交叉熵损失类,它继承自nn.Module
,表示这是一个可在 PyTorch 中作为神经网络模块使用的类。@MODELS.register_module()
是一个装饰器,用于将这个类注册到一个特定的模块注册表中,以便在其他地方可以方便地通过名称来调用和实例化这个类。
构造函数参数解释:
use_sigmoid
:布尔类型,可选参数,默认为False
。表示预测是否使用 sigmoid 激活函数而不是 softmax。如果为True
,则在计算损失时会有不同的处理。use_mask
:布尔类型,可选参数,默认为False
。表示是否使用掩码交叉熵损失。如果为True
,会采用特定的损失计算方式。reduction
:字符串类型,可选参数,默认为'mean'
。表示损失的归约方法,可以是'none'
(不进行归约)、'mean'
(求平均)或'sum'
(求和)。class_weight
:列表类型,可选参数,默认为None
。表示每个类别的权重。ignore_index
:整数或None
,可选参数,默认为None
。表示要忽略的标签索引。loss_weight
:浮点数类型,可选参数,默认为1.0
。表示损失的权重。avg_non_ignore
:布尔类型,默认为False
。表示损失是否只在非忽略的目标上进行平均。
构造函数执行过程:
-
调用父类构造函数:
super(CrossEntropyLoss, self).__init__()
调用父类nn.Module
的构造函数,进行必要的初始化操作。
-
断言和参数设置:
assert (use_sigmoid is False) or (use_mask is False)
确保use_sigmoid
和use_mask
不能同时为True
,因为这两种方式是互斥的选择。- 分别将传入的参数赋值给类的属性,如
self.use_sigmoid
、self.use_mask
、self.reduction
等。
-
警告提示:
- 如果
ignore_index
不为None
且avg_non_ignore
为False
且reduction
为'mean'
,则发出一个警告,提示用户如果想要忽略特定标签并在非忽略的标签上平均损失,可以将avg_non_ignore
设置为True
,这与 PyTorch 官方的交叉熵损失计算方式一致。
- 如果
-
根据参数选择损失计算准则:
- 如果
self.use_sigmoid
为True
,则将self.cls_criterion
设置为binary_cross_entropy
函数,表示使用二值交叉熵损失计算方式。 - 如果
self.use_mask
为True
,则将self.cls_criterion
设置为mask_cross_entropy
函数,表示使用掩码交叉熵损失计算方式。 - 如果都不是,则将
self.cls_criterion
设置为cross_entropy
函数,表示使用普通的交叉熵损失计算方式。
- 如果
总的来说,这个类提供了一种灵活的方式来配置和计算交叉熵损失,可以根据不同的需求选择不同的损失计算方式和参数设置。
def extra_repr(self):
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
这个方法用于提供额外的字符串表示,在打印对象时可以显示更多信息。
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=None,
**kwargs):
"""Forward function.
Args:
cls_score (torch.Tensor): The prediction.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The method used to reduce the
loss. Options are "none", "mean" and "sum".
ignore_index (int | None): The label index to be ignored.
If not None, it will override the default value. Default: None.
Returns:
torch.Tensor: The calculated loss.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if ignore_index is None:
ignore_index = self.ignore_index
if self.class_weight is not None:
class_weight = cls_score.new_tensor(
self.class_weight, device=cls_score.device)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
这个方法是模型的前向传播方法。它接受预测值、标签以及一些参数,根据参数确定损失的归约方法和忽略索引,然后调用相应的损失计算函数,并乘以损失权重,最后返回计算得到的损失。
以下是对这段代码的详细解析:
函数功能:
这是一个类方法,用于计算交叉熵损失的前向传播过程。它接受预测值、真实标签以及一些可选参数,通过调用特定的损失计算准则(cls_criterion
)来计算损失,并根据配置进行加权和处理,最后返回计算得到的损失值。
参数解释:
cls_score
:torch.Tensor
类型,模型的预测结果。label
:torch.Tensor
类型,预测结果的真实学习标签。weight
:torch.Tensor
类型,可选参数,样本级别的损失权重。avg_factor
:整数类型,可选参数,用于平均损失的因子。reduction_override
:字符串类型,可选参数,用于覆盖默认的损失归约方法,可选值为None
、'none'
、'mean'
和'sum'
。ignore_index
:整数或None
,可选参数,要忽略的标签索引,如果不为None
,将覆盖默认的忽略索引。**kwargs
:额外的关键字参数,用于传递给损失计算准则。
函数执行过程:
-
参数检查和设置:
assert reduction_override in (None, 'none', 'mean', 'sum')
,确保reduction_override
参数的值是合法的。- 如果
reduction_override
不为None
,则使用它作为损失归约方法;否则,使用类的属性self.reduction
作为归约方法。 - 如果
ignore_index
为None
,则将其设置为类的属性self.ignore_index
。
-
处理类别权重:
- 如果类有定义的类别权重
self.class_weight
,则创建一个新的张量class_weight
,其值为类别权重,并放置在与cls_score
相同的设备上。 - 如果没有定义类别权重,则将
class_weight
设置为None
。
- 如果类有定义的类别权重
-
计算损失:
- 使用
self.cls_criterion
(根据类的配置确定的具体损失计算准则)来计算损失。传入预测值cls_score
、标签label
、权重weight
、类别权重class_weight
、归约方法reduction
、平均因子avg_factor
、忽略索引ignore_index
以及avg_non_ignore
属性和额外的关键字参数**kwargs
。 - 将计算得到的损失乘以类的属性
self.loss_weight
,得到最终的损失值loss_cls
。
- 使用
-
返回损失值:
- 返回计算得到的损失值
loss_cls
。
- 返回计算得到的损失值
总的来说,这个方法在交叉熵损失类中起到了核心的前向计算作用,通过灵活的参数配置和调用特定的损失计算准则,能够适应不同的任务需求。
四、定义 CrossEntropyCustomLoss 类
@MODELS.register_module()
class CrossEntropyCustomLoss(CrossEntropyLoss):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
num_classes=-1,
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyCustomLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
num_classes (int): Number of classes to classify.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyCustomLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self.num_classes = num_classes
assert self.num_classes != -1
# custom output channels of the classifier
self.custom_cls_channels = True
# custom activation of cls_score
self.custom_activation = True
# custom accuracy of the classsifier
self.custom_accuracy = True
这个类继承自CrossEntropyLoss
,是一个自定义的交叉熵损失类。构造函数除了接受与CrossEntropyLoss
相同的参数外,还接受一个num_classes
参数,表示分类的类别数。在构造函数中,除了进行与父类相同的初始化操作外,还设置了一些自定义的属性,如custom_cls_channels
、custom_activation
和custom_accuracy
等。
以下是对这段代码的详细解析:
类功能及装饰器说明:
CrossEntropyCustomLoss
是一个自定义的交叉熵损失类,它继承自CrossEntropyLoss
类。和前面的CrossEntropyLoss
一样,它也被@MODELS.register_module()
装饰器注册到特定的模块注册表中,以便在其他地方可以方便地调用和实例化。
构造函数参数解释:
use_sigmoid
:布尔类型,可选参数,默认为False
。表示预测是否使用 sigmoid 激活函数。use_mask
:布尔类型,可选参数,默认为False
。表示是否使用掩码交叉熵损失。reduction
:字符串类型,可选参数,默认为'mean'
。表示损失的归约方法。num_classes
:整数类型,表示要分类的类别数量。class_weight
:列表类型,可选参数,默认为None
。表示每个类别的权重。ignore_index
:整数或None
,可选参数,默认为None
。表示要忽略的标签索引。loss_weight
:浮点数类型,可选参数,默认为1.0
。表示损失的权重。avg_non_ignore
:布尔类型,默认为False
。表示损失是否只在非忽略的目标上进行平均。
构造函数执行过程:
-
调用父类构造函数:
super(CrossEntropyCustomLoss, self).__init__()
调用父类CrossEntropyLoss
的构造函数,进行一些基本的初始化操作。
-
断言和参数设置:
- 与
CrossEntropyLoss
类似,确保use_sigmoid
和use_mask
不能同时为True
。 - 将传入的参数赋值给类的属性,如
self.use_sigmoid
、self.use_mask
、self.reduction
等。
- 与
-
警告提示:
- 如果
ignore_index
不为None
且avg_non_ignore
为False
且reduction
为'mean'
,发出警告提示用户可以设置avg_non_ignore
为True
以实现与 PyTorch 官方交叉熵损失相同的行为。
- 如果
-
根据参数选择损失计算准则:
- 与
CrossEntropyLoss
类似,根据use_sigmoid
和use_mask
的值选择合适的损失计算函数赋值给self.cls_criterion
。
- 与
-
额外的属性设置:
self.num_classes = num_classes
,将传入的类别数量赋值给类的属性。assert self.num_classes!= -1
,确保传入的类别数量不是 -1。- 设置三个自定义属性为
True
,分别表示分类器的自定义输出通道、自定义激活函数和自定义准确率计算。
总的来说,这个类在继承CrossEntropyLoss
的基础上,增加了对特定类别数量的要求,并设置了一些自定义的属性,用于在特定的场景下进行更灵活的交叉熵损失计算和处理。
def get_cls_channels(self, num_classes):
assert num_classes == self.num_classes
if not self.use_sigmoid:
return num_classes + 1
else:
return num_classes
这个方法用于获取分类器的输出通道数,根据是否使用 sigmoid 激活进行不同的计算。
def get_activation(self, cls_score):
fine_cls_score = cls_score[:, :self.num_classes]
if not self.use_sigmoid:
bg_score = cls_score[:, [-1]]
new_score = torch.cat([fine_cls_score, bg_score], dim=-1)
scores = F.softmax(new_score, dim=-1)
else:
score_classes = fine_cls_score.sigmoid()
score_neg = 1 - score_classes.sum(dim=1, keepdim=True)
score_neg = score_neg.clamp(min=0, max=1)
scores = torch.cat([score_classes, score_neg], dim=1)
return scores
这个方法用于获取分类器的激活函数输出,根据是否使用 sigmoid 激活进行不同的处理。
以下是对这段代码的详细解析:
方法功能:
这个方法用于获取分类器的激活输出。它根据不同的配置(是否使用 sigmoid)对输入的分类得分进行处理,得到最终的激活后的分数。
参数解释:
cls_score
:torch.Tensor
类型,分类器的输出得分张量。
方法执行过程:
-
提取精细分类得分:
fine_cls_score = cls_score[:, :self.num_classes]
,从输入的分类得分中提取出前self.num_classes
个类别对应的得分,即精细分类得分。
-
根据是否使用 sigmoid 进行不同处理:
-
如果不使用 sigmoid:
bg_score = cls_score[:, [-1]]
,获取背景类别的得分,这里使用了切片[-1]
来获取最后一个维度上的最后一个元素。new_score = torch.cat([fine_cls_score, bg_score], dim=-1)
,将精细分类得分和背景得分在最后一个维度上进行拼接。scores = F.softmax(new_score, dim=-1)
,对拼接后的得分进行 softmax 激活,得到最终的激活后的分数。
-
如果使用 sigmoid:
score_classes = fine_cls_score.sigmoid()
,对精细分类得分进行 sigmoid 激活,得到各个类别的激活得分。score_neg = 1 - score_classes.sum(dim=1, keepdim=True)
,计算除了正类之外的其他类别的总和的补集,即负类的得分。这里使用了sum
函数在维度 1 上进行求和,并通过keepdim=True
保持维度。score_neg = score_neg.clamp(min=0, max=1)
,将负类得分限制在 0 到 1 的范围内。scores = torch.cat([score_classes, score_neg], dim=1)
,将正类得分和负类得分在维度 1 上进行拼接,得到最终的激活后的分数。
-
-
返回激活后的分数:
- 返回计算得到的激活后的分数
scores
。
- 返回计算得到的激活后的分数
总的来说,这个方法提供了一种根据不同的配置来处理分类器输出得分并得到激活后分数的方式,可用于后续的计算和评估。
def get_accuracy(self, cls_score, labels):
fine_cls_score = cls_score[:, :self.num_classes]
pos_inds = labels < self.num_classes
acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
acc = dict()
acc['acc_classes'] = acc_classes
return acc
这个方法用于获取分类器的准确率,根据标签选择有效的预测值进行准确率计算。
以下是对这段代码的详细解析:
方法功能:
这个方法用于计算分类器的准确率。它根据输入的分类得分和真实标签,提取出正类的分类得分,并计算在这些正类上的准确率,最后以字典的形式返回准确率结果。
参数解释:
cls_score
:torch.Tensor
类型,分类器的输出得分张量。labels
:torch.Tensor
类型,真实的标签张量。
方法执行过程:
-
提取精细分类得分:
fine_cls_score = cls_score[:, :self.num_classes]
,从输入的分类得分中提取出前self.num_classes
个类别对应的得分,即精细分类得分。
-
确定正类索引:
pos_inds = labels < self.num_classes
,创建一个布尔张量,其中值为True
的位置表示对应的标签小于self.num_classes
,即属于正类。
-
计算正类准确率:
acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
,调用accuracy
函数(假设这是一个自定义的准确率计算函数),传入正类的分类得分(通过索引pos_inds
从fine_cls_score
中选择)和正类的标签(同样通过索引pos_inds
从labels
中选择),计算正类的准确率。
-
构建准确率字典并返回:
acc = dict()
创建一个空字典。acc['acc_classes'] = acc_classes
将正类准确率添加到字典中,键为'acc_classes'
。return acc
返回包含正类准确率的字典。
综上所述,CrossEntropyLoss
和CrossEntropyCustomLoss
提供了多种交叉熵损失的计算方式,可以根据不同的任务需求进行配置和使用。通过注册为 MMDetection 框架中的模块,可以方便地在模型中调用和配置这些损失函数。