[MMDetection] Source File Walkthrough: mmdet/models/losses/focal_loss.py

A Deep Dive into the Focal Loss Implementation in MMDetection

In object detection, the design of the loss function plays a crucial role in model performance. The focal_loss.py file in MMDetection implements Focal Loss, a powerful loss function proposed to address the class-imbalance problem.

For a detailed analysis of the original paper, see: Loss: Focal Loss for Dense Object Detection

I. File Overview

mmdet/models/losses/focal_loss.py contains a set of functions and classes related to Focal Loss, providing effective loss computation for object detection tasks.

II. Function Analysis

1. py_sigmoid_focal_loss

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # Actually, pt here denotes (1 - pt) in the Focal Loss paper
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    # Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                #  which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                #  in FSAF. But it may be flattened of shape
                #  (num_priors x num_class, ), while loss is still of shape
                #  (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
1.1 Purpose

This function implements the PyTorch version of Focal Loss. It addresses class imbalance by up-weighting hard-to-classify samples, which is especially useful in tasks such as object detection where easy negatives dominate.

1.2 Parameters
  • pred: prediction tensor of shape (N, C), where N is the number of samples and C is the number of classes. These are raw logits, i.e. the values before the sigmoid activation; after sigmoid they represent per-class probabilities.
  • target: ground-truth label tensor. Its shape can be (N,) for class-index labels or (N, C) for one-hot labels.
  • weight: optional sample-weight tensor of shape (N,) or (N, C), used to weight individual samples or individual sample-class entries.
  • gamma: the focusing parameter of Focal Loss, controlling how strongly hard samples are up-weighted. Defaults to 2.0.
  • alpha: the balancing factor between positive and negative samples. Defaults to 0.25.
  • reduction: how the loss is reduced; one of 'mean', 'sum', or 'none'. Defaults to 'mean'.
  • avg_factor: the factor used to average the loss, applied when reduction is 'mean'.
1.3 Step-by-Step Walkthrough
  1. Apply sigmoid and align label dtype

    • pred_sigmoid = pred.sigmoid(): applies the sigmoid activation to convert logits into probabilities.
    • target = target.type_as(pred): casts the label tensor to the same dtype as the predictions.
  2. Compute the modulating term pt

    • pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target): despite its name, pt here corresponds to $(1 - p_t)$ in the Focal Loss paper, where $p_t$ is the model's predicted probability for the true class: for a positive sample, $p_t$ is the predicted probability of the positive class; for a negative sample, $p_t$ is the predicted probability of the negative class.
  3. Compute the focal weight

    • focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma): combines the balancing factor alpha, the label target, and the modulating term pt. Hard positives (small $p_t$) receive larger weights, and likewise for hard negatives.
  4. Compute binary cross-entropy and apply the focal weight

    • loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight: first computes the binary cross-entropy from the raw logits pred and the labels target (F.binary_cross_entropy_with_logits applies the sigmoid internally, in a numerically stable way), then multiplies by the focal weight to obtain the element-wise focal loss.
  5. Apply sample weights (a shape-handling sketch follows after this list)

    • If weight is not None:
      • If weight's shape differs from loss's shape, it is adjusted:
        • If weight's first dimension matches loss's first dimension, weight is likely a per-prior weight of shape (num_priors,) without a class axis; it is reshaped to (num_priors, 1) so it broadcasts over classes.
        • If weight has the same number of elements as loss, it is likely a flattened per-prior-per-class weight of shape (num_priors * num_classes,), while loss has shape (num_priors, num_classes); weight is reshaped to match.
      • An assertion then ensures weight and loss have the same number of dimensions.
    • loss = weight_reduce_loss(loss, weight, reduction, avg_factor): applies the weights and reduces the loss according to reduction and avg_factor.
  6. Return the loss

    • The computed focal loss is returned.
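
To make the two weight layouts concrete, here is a minimal, self-contained sketch of the reshaping branch (the tensors and sizes are made up for illustration):

import torch

loss = torch.rand(5, 3)                  # (num_priors, num_classes)

# Case 1: one weight per prior, shape (num_priors,)
w_prior = torch.rand(5)
w_prior = w_prior.view(-1, 1)            # -> (5, 1); broadcasts over the class axis

# Case 2: flattened per-prior-per-class weights, shape (num_priors * num_classes,)
w_flat = torch.rand(15)
assert w_flat.numel() == loss.numel()
w_flat = w_flat.view(loss.size(0), -1)   # -> (5, 3); matches loss element-wise

weighted = loss * w_prior                # both layouts now multiply cleanly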
1.4 Relation to the Paper

(1) Overview

py_sigmoid_focal_loss implements the Focal Loss proposed in the paper "Focal Loss for Dense Object Detection". Focal Loss targets the extreme foreground-background class imbalance in dense object detection: it down-weights easy samples and up-weights hard ones, so the model focuses on the difficult samples and detection accuracy improves.

(2) Corresponding Formulas

  1. Sigmoid function

    • pred_sigmoid = pred.sigmoid() converts the model output pred into probabilities via the sigmoid function $\sigma(x) = \frac{1}{1 + e^{-x}}$.
  2. Modulating factor

    • pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target): the variable pt corresponds to $(1 - p_t)$ in the paper, where $p_t$ is the predicted probability of the true class.
    • focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma): this is the modulating factor of Focal Loss, where $\alpha$ is the balancing factor between positive and negative samples and $\gamma$ is the focusing parameter that controls the strength of the modulation.
  3. Focal Loss formula

    • loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight combines the binary cross-entropy with the modulating factor. The binary cross-entropy is $L_{BCE} = -y\log(p) - (1 - y)\log(1 - p)$, where $y$ is the label and $p$ the predicted probability. The focal loss is $FL(p_t) = -(1 - p_t)^{\gamma}\log(p_t)$; in practice the $\alpha$-balanced form $FL(p_t) = -\alpha_t(1 - p_t)^{\gamma}\log(p_t)$ is used.
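
Putting the pieces together: writing $y \in \{0, 1\}$ for the binary target and $p = \sigma(x)$ for the sigmoid probability, the code computes exactly the $\alpha$-balanced focal loss:

$$p_t = y\,p + (1 - y)(1 - p), \qquad \texttt{pt} = 1 - p_t, \qquad \alpha_t = y\,\alpha + (1 - y)(1 - \alpha)$$

$$\texttt{loss} = \underbrace{\alpha_t (1 - p_t)^{\gamma}}_{\texttt{focal\_weight}} \cdot \underbrace{\bigl(-\log p_t\bigr)}_{L_{BCE}} = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) = FL(p_t)$$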

(3) A Worked Example

Suppose we have a binary problem with logits pred = $[0.5, 0.3]$, labels target = $[1, 0]$, $\gamma = 2$, and $\alpha = 0.25$.

  1. Compute the sigmoid probabilities:

    • pred_sigmoid = [0.6224593, 0.5744425]
  2. Compute $p_t$ and $(1 - p_t)$:

    • For the first sample (positive), $p_t = 0.6224593$, so $(1 - p_t) = 0.3775407$.
    • For the second sample (negative), $p_t = 1 - 0.5744425 = 0.4255575$, so $(1 - p_t) = 0.5744425$.
  3. Compute the modulating factor:

    • For the first sample (positive), pt $= (1 - 0.6224593) \times 1 + 0.6224593 \times (1 - 1) = 0.3775407$, and focal_weight $= (0.25 \times 1 + 0.75 \times 0) \times 0.3775407^{2} \approx 0.0356342$.
    • For the second sample (negative), pt $= (1 - 0.5744425) \times 0 + 0.5744425 \times (1 - 0) = 0.5744425$, and focal_weight $= (0.25 \times 0 + 0.75 \times 1) \times 0.5744425^{2} \approx 0.2474882$.
  4. Compute the binary cross-entropy:

    • For the first sample (positive), $L_{BCE} = -\log(0.6224593) \approx 0.4740770$.
    • For the second sample (negative), $L_{BCE} = -\log(1 - 0.5744425) = -\log(0.4255575) \approx 0.8543552$.
  5. Compute the focal loss:

    • For the first sample (positive), loss $= 0.4740770 \times 0.0356342 \approx 0.0168934$.
    • For the second sample (negative), loss $= 0.8543552 \times 0.2474882 \approx 0.2114428$.

Note how the easy, correctly classified positive ($p_t \approx 0.62$) ends up with a much smaller loss than the misclassified negative ($p_t \approx 0.43$); this is exactly the focusing behavior Focal Loss is designed for.
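
The numbers above can be checked directly with a few lines of PyTorch (a minimal sketch mirroring the body of py_sigmoid_focal_loss; printed values are rounded):

import torch
import torch.nn.functional as F

pred = torch.tensor([0.5, 0.3])       # logits
target = torch.tensor([1.0, 0.0])
gamma, alpha = 2.0, 0.25

p = pred.sigmoid()
pt = (1 - p) * target + p * (1 - target)                       # = (1 - p_t)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight
print(loss)   # tensor([0.0169, 0.2114])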

(4) Mapping Back to the Code

  1. pred_sigmoid = pred.sigmoid(): computes the sigmoid probabilities.
  2. target = target.type_as(pred): casts the labels to the dtype of the predictions.
  3. pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target): computes the $(1 - p_t)$ term.
  4. focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma): computes the modulating factor.
  5. loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight: combines the binary cross-entropy with the modulating factor.
  6. if weight is not None: and the lines that follow: applies sample weights if given.
  7. loss = weight_reduce_loss(loss, weight, reduction, avg_factor): reduces the loss according to the chosen reduction method (mean, sum, etc.) and the averaging factor; a sketch of this helper follows below.
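
For reference, weight_reduce_loss lives in mmdet/models/losses/utils.py; its behavior is roughly the following (a paraphrased sketch, not the verbatim source):

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weights, then reduce the loss (sketch)."""
    if weight is not None:
        loss = loss * weight
    if avg_factor is None:
        # Plain reduction over all elements.
        if reduction == 'mean':
            loss = loss.mean()
        elif reduction == 'sum':
            loss = loss.sum()
    else:
        if reduction == 'mean':
            # 'mean' with avg_factor divides the summed loss by avg_factor
            # (e.g. the number of positive anchors) instead of by numel();
            # recent mmdet versions add a tiny eps to avg_factor for safety.
            loss = loss.sum() / avg_factor
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss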

2. py_focal_loss_with_prob

def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
    Different from `py_sigmoid_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): The learning label of the prediction.
            The target shape support (N,C) or (N,), (N,C) means one-hot form.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if pred.dim() != target.dim():
        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes + 1)
        target = target[:, :num_classes]

    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                #  which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                #  in FSAF. But it may be flattened of shape
                #  (num_priors x num_class, ), while loss is still of shape
                #  (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
2.1 Relation to the Paper

py_focal_loss_with_prob is likewise an implementation of the Focal Loss from "Focal Loss for Dense Object Detection". Unlike py_sigmoid_focal_loss, this function accepts probabilities (already-activated scores) as input.

2.2 Corresponding Formulas
  1. Modulating factor

    • pt = (1 - pred) * target + pred * (1 - target): pt again corresponds to $(1 - p_t)$ in the paper, except that the input probabilities pred are used directly.
    • focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma): again the modulating factor of Focal Loss, with balancing factor $\alpha$ and focusing parameter $\gamma$.
  2. Focal Loss formula

    • loss = F.binary_cross_entropy(pred, target, reduction='none') * focal_weight combines the binary cross-entropy with the modulating factor. The binary cross-entropy is $L_{BCE} = -y\log(p) - (1 - y)\log(1 - p)$, where $y$ is the label and $p$ the predicted probability; the balanced focal loss is $FL(p_t) = -\alpha_t(1 - p_t)^{\gamma}\log(p_t)$. Note that F.binary_cross_entropy (rather than the with_logits variant) is used here because pred is already a probability.
2.3 A Worked Example

Suppose we have a three-class problem with probabilities pred = $[0.7, 0.2, 0.1]$, one-hot label target = $[1, 0, 0]$, $\gamma = 2$, and $\alpha = 0.25$.

  1. Compute $p_t$ and $(1 - p_t)$:

    • For the first class (the true class), $p_t = 0.7$, so $(1 - p_t) = 0.3$.
    • For the second class, $p_t = 1 - 0.2 = 0.8$, so $(1 - p_t) = 0.2$.
    • For the third class, $p_t = 1 - 0.1 = 0.9$, so $(1 - p_t) = 0.1$.
  2. Compute the modulating factor:

    • For the first class (true class), pt $= (1 - 0.7) \times 1 + 0.7 \times (1 - 1) = 0.3$, and focal_weight $= (0.25 \times 1 + 0.75 \times 0) \times 0.3^{2} = 0.0225$.
    • For the second class, pt $= (1 - 0.2) \times 0 + 0.2 \times (1 - 0) = 0.2$, and focal_weight $= (0.25 \times 0 + 0.75 \times 1) \times 0.2^{2} = 0.03$.
    • For the third class, pt $= (1 - 0.1) \times 0 + 0.1 \times (1 - 0) = 0.1$, and focal_weight $= 0.75 \times 0.1^{2} = 0.0075$.
  3. Compute the binary cross-entropy:

    • For the first class (true class), $L_{BCE} = -\log(0.7) \approx 0.3566749$.
    • For the second class, $L_{BCE} = -\log(1 - 0.2) \approx 0.2231436$.
    • For the third class, $L_{BCE} = -\log(1 - 0.1) \approx 0.1053605$.
  4. Compute the focal loss:

    • For the first class (true class), loss $= 0.3566749 \times 0.0225 \approx 0.0080252$.
    • For the second class, loss $= 0.2231436 \times 0.03 \approx 0.0066943$.
    • For the third class, loss $= 0.1053605 \times 0.0075 \approx 0.0007902$.
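
Again this is easy to verify numerically (a minimal sketch mirroring the body of py_focal_loss_with_prob; printed values are rounded):

import torch
import torch.nn.functional as F

pred = torch.tensor([0.7, 0.2, 0.1])    # probabilities, not logits
target = torch.tensor([1.0, 0.0, 0.0])
gamma, alpha = 2.0, 0.25

pt = (1 - pred) * target + pred * (1 - target)                 # = (1 - p_t)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy(pred, target, reduction='none') * focal_weight
print(loss)   # tensor([0.0080, 0.0067, 0.0008])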
2.4 Mapping Back to the Code
  1. if pred.dim() != target.dim():: if the dimensions of the predictions and the labels differ, the labels are class indices rather than one-hot vectors and need converting (see the one-hot sketch after this list).
    • num_classes = pred.size(1): reads off the number of classes.
    • target = F.one_hot(target, num_classes=num_classes + 1): converts the labels to one-hot form; the extra num_classes + 1 column accommodates the background label.
    • target = target[:, :num_classes]: keeps only the first num_classes columns, dropping the background column so that background samples become all-zero rows.
  2. target = target.type_as(pred): casts the labels to the dtype of the predictions.
  3. pt = (1 - pred) * target + pred * (1 - target): computes the $(1 - p_t)$ term.
  4. focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma): computes the modulating factor.
  5. loss = F.binary_cross_entropy(pred, target, reduction='none') * focal_weight: combines the binary cross-entropy with the modulating factor.
  6. if weight is not None: and the lines that follow: applies sample weights if given.
  7. loss = weight_reduce_loss(loss, weight, reduction, avg_factor): reduces the loss according to the chosen reduction method and averaging factor.
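
A quick sketch of the one-hot conversion, showing how the background label (equal to num_classes) ends up as an all-zero row:

import torch
import torch.nn.functional as F

num_classes = 3
labels = torch.tensor([0, 2, 3])   # 3 is the background label here

onehot = F.one_hot(labels, num_classes=num_classes + 1)[:, :num_classes]
print(onehot)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 0, 0]])   <- the background sample is negative for every class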

3. sigmoid_focal_loss

def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                #  which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                #  in FSAF. But it may be flattened of shape
                #  (num_priors x num_class, ), while loss is still of shape
                #  (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
3.1 Relation to the Paper

sigmoid_focal_loss is a thin wrapper around the CUDA implementation of Focal Loss from "Focal Loss for Dense Object Detection" (the _sigmoid_focal_loss op imported from mmcv.ops). Like the pure-PyTorch versions, it addresses the positive/negative class imbalance in object detection by introducing the modulating factor so the model focuses on hard samples.

3.2 Corresponding Formulas
  1. Modulating factor

    • The computation matches the two functions above: $\gamma$ is the focusing parameter controlling the strength of the modulation, and $\alpha$ is the balancing factor between positive and negative samples.
    • The modulating factor is focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma), with pt computed as in the previous functions; the difference is that here the whole computation happens inside the CUDA kernel.
  2. Focal Loss formula

    • The final loss again combines binary cross-entropy with the modulating factor: $FL(p_t) = -\alpha_t(1 - p_t)^{\gamma}\log(p_t)$.
3.3 A Worked Example

Suppose again the binary setting with logits pred = $[0.5, 0.3]$, labels target = $[1, 0]$, $\gamma = 2$, and $\alpha = 0.25$. (The real kernel expects pred of shape (N, C) and integer class labels; the 1-D tensors here are only for illustrating the arithmetic.)

  1. First call _sigmoid_focal_loss to compute the unreduced loss:

    • loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, 'none'). The CUDA kernel implements the same mathematics as py_sigmoid_focal_loss, so for these inputs it would return approximately $[0.0169, 0.2114]$, matching the worked example in section 1.4.
  2. Handle the weights:

    • If a sample weight weight is given, it is reshaped according to its shape.
    • If weight.size(0) == loss.size(0), the weight has shape (num_priors,) with no class axis and is reshaped to (num_priors, 1) via weight = weight.view(-1, 1).
    • If the shapes differ but the element counts match, the weight was flattened and is reshaped to match the loss via weight = weight.view(loss.size(0), -1).
    • assert weight.ndim == loss.ndim then checks that the weight and the loss have the same number of dimensions.
  3. Compute the final loss:

    • loss = weight_reduce_loss(loss, weight, reduction, avg_factor): the loss is reduced according to the specified reduction method (mean, sum, etc.) and the averaging factor avg_factor, yielding the final value.
3.4 Mapping Back to the Code
  1. loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, 'none'): calls the underlying CUDA focal-loss op to obtain the unreduced loss. Function.apply does not accept keyword arguments, which is why the weighted_loss decorator is not used here.
  2. if weight is not None: and the lines that follow: reshape the sample weights, if given, so their dimensionality matches the loss.
  3. loss = weight_reduce_loss(loss, weight, reduction, avg_factor): reduces the loss according to the chosen reduction method and averaging factor; a consistency-check sketch against the pure-PyTorch version follows below.
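
On a machine with a CUDA build of mmcv, the wrapper can be checked against the pure-PyTorch version (a hedged sketch; the shapes and tolerances are illustrative assumptions):

import torch
import torch.nn.functional as F
from mmdet.models.losses.focal_loss import (py_sigmoid_focal_loss,
                                            sigmoid_focal_loss)

num_classes = 10
pred = torch.randn(4, num_classes, device='cuda')
labels = torch.randint(0, num_classes, (4,), device='cuda')

# The CUDA op takes integer class labels; the Python version wants one-hot.
onehot = F.one_hot(labels, num_classes=num_classes + 1)[:, :num_classes].float()

loss_cuda = sigmoid_focal_loss(pred, labels)    # default reduction='mean'
loss_py = py_sigmoid_focal_loss(pred, onehot)
torch.testing.assert_close(loss_cuda, loss_py, rtol=1e-4, atol=1e-6)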

4. The FocalLoss Class

4.1 Source Code
@MODELS.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
                The target shape support (N,C) or (N,), (N,C) means
                one-hot form.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if pred.dim() == target.dim():
                    # this means that target is already in One-Hot form.
                    calculate_loss_func = py_sigmoid_focal_loss
                elif torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls
4.2 Analysis
  1. Overview

The forward method is the core of the FocalLoss class. It takes the predictions pred, the labels target, and the optional arguments weight, avg_factor, and reduction_override, and returns the computed loss tensor.

  2. Parameters

1). pred: the prediction tensor, i.e. the model output for each sample; raw logits by default, or probabilities if activated=True.

2). target: the label tensor, either class indices or one-hot vectors.

3). weight (optional): a loss-weight tensor applied per prediction. Defaults to None.

4). avg_factor (optional): the averaging factor for the loss. Defaults to None.

5). reduction_override (optional): overrides the default reduction method; one of "none", "mean", or "sum". Defaults to None.

  3. Computation

1). Validate the reduction override:

  • assert reduction_override in (None, 'none', 'mean', 'sum'): ensures reduction_override takes a legal value.

2). Determine the reduction method:

  • reduction = reduction_override if reduction_override else self.reduction: uses the override if given; otherwise falls back to the reduction set at construction time.

3). Dispatch to a loss implementation:

  • If self.use_sigmoid is True (the constructor asserts this is the only supported case):
    • If self.activated is True, the input is already a probability, so py_focal_loss_with_prob is assigned to calculate_loss_func.
    • If the input is not activated and pred and target have the same number of dimensions, target is already in one-hot form, so py_sigmoid_focal_loss is assigned to calculate_loss_func.
    • If the input is not activated, CUDA is available, and pred lives on a GPU, the CUDA-backed sigmoid_focal_loss is assigned to calculate_loss_func.
    • Otherwise, target is converted to one-hot form and py_sigmoid_focal_loss is assigned to calculate_loss_func.
  • If self.use_sigmoid is False, a NotImplementedError is raised, since only the sigmoid focal loss is currently supported.

4). Compute the loss:

  • loss_cls = self.loss_weight * calculate_loss_func(pred, target, weight, gamma=self.gamma, alpha=self.alpha, reduction=reduction, avg_factor=avg_factor): evaluates the selected implementation and scales the result by the loss weight set at construction time.

5). Return the loss:

  • return loss_cls: returns the computed loss tensor.
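
A short usage sketch (assuming MMDetection is installed; the shapes and the avg_factor value are illustrative assumptions):

import torch
from mmdet.models.losses import FocalLoss

loss_fn = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0)

pred = torch.randn(8, 80)              # logits for 8 priors over 80 classes
target = torch.randint(0, 81, (8,))    # class indices; 80 denotes background

num_pos = max(int((target < 80).sum()), 1)   # avg_factor: number of positives
loss = loss_fn(pred, target, avg_factor=num_pos)

In a detector config the same module is typically instantiated via loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), as in the RetinaNet configs.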

5. The FocalCustomLoss Class

@MODELS.register_module()
class FocalCustomLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 num_classes=-1,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss for V3Det <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            num_classes (int): Number of classes to classify.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(FocalCustomLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.num_classes = num_classes
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

        assert self.num_classes != -1

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def get_cls_channels(self, num_classes):
        assert num_classes == self.num_classes
        return num_classes

    def get_activation(self, cls_score):

        fine_cls_score = cls_score[:, :self.num_classes]

        score_classes = fine_cls_score.sigmoid()

        return score_classes

    def get_accuracy(self, cls_score, labels):

        fine_cls_score = cls_score[:, :self.num_classes]

        pos_inds = labels < self.num_classes
        acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_classes'] = acc_classes
        return acc

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:

            num_classes = pred.size(1)
            target = F.one_hot(target, num_classes=num_classes + 1)
            target = target[:, :num_classes]
            calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls
5.1 Overview

FocalCustomLoss is a PyTorch module implementing a customized focal loss for a specific detection setting (V3Det, per the docstring). It inherits from nn.Module and can be used as a loss layer inside a network. Its design follows the paper "Focal Loss for Dense Object Detection" and targets the class-imbalance problem in object detection.

5.2 Constructor (__init__)
  1. Parameters

    • use_sigmoid=True: whether a sigmoid is applied to the predictions. Defaults to True; only the sigmoid focal loss is currently supported.
    • num_classes=-1: the number of classes to classify. It must be set to the actual class count at instantiation time; leaving the default -1 trips an assertion.
    • gamma=2.0: the focusing parameter used in the modulating factor. Larger values make the model attend more to hard samples.
    • alpha=0.25: the balancing factor between positive and negative samples.
    • reduction='mean': the reduction method, one of 'none', 'mean', or 'sum'. Defaults to 'mean', i.e. the loss is averaged.
    • loss_weight=1.0: the weight of this loss, used to scale its contribution to the total loss.
    • activated=False: if True, the input has already been activated and is treated as probabilities; otherwise it is treated as logits.
  2. Assertions and custom attributes

    • assert use_sigmoid is True: only the sigmoid focal loss is supported.
    • assert self.num_classes != -1: the class count must be specified at instantiation.
    • Three custom flags are then set:
      • self.custom_cls_channels = True: the classifier's output channels are custom.
      • self.custom_activation = True: the activation applied to cls_score is custom.
      • self.custom_accuracy = True: the classifier's accuracy computation is custom.
5.3 Methods
  1. get_cls_channels:

    • Purpose: returns the number of output channels of the classifier.
    • Implementation: takes num_classes, asserts it equals the value given at construction, and returns it.
  2. get_activation:

    • Purpose: activates the classifier output.
    • Implementation: takes the classifier scores cls_score, slices out the first self.num_classes columns as fine_cls_score, applies a sigmoid to obtain the probabilities score_classes, and returns them.
  3. get_accuracy:

    • Purpose: computes the classifier's accuracy.
    • Implementation: takes cls_score and labels. It slices out the first self.num_classes columns as fine_cls_score, selects the foreground samples via pos_inds = labels < self.num_classes, computes their accuracy acc_classes, and returns it wrapped in a dict.
  4. forward:

    • Purpose: the forward pass, computing the focal loss.
    • Parameters:
      • pred: the model predictions, a PyTorch tensor.
      • target: the ground-truth labels, also a PyTorch tensor.
      • weight: optional per-prediction loss weights.
      • avg_factor: optional averaging factor for the loss.
      • reduction_override: optional override of the default reduction method.
    • Implementation:
      • First validates the reduction_override argument.
      • If self.use_sigmoid is True (the only supported case):
        • Reads the class count num_classes from the predictions.
        • Converts the labels to one-hot form, dropping the extra background column.
        • Selects py_sigmoid_focal_loss as the loss function.
        • Computes loss_cls by passing the predictions, one-hot labels, weights, focusing parameter, balancing factor, reduction method, and averaging factor to calculate_loss_func, scaled by the loss weight.
      • If self.use_sigmoid is False, raises NotImplementedError, since only the sigmoid focal loss is supported.
      • Finally returns the computed loss loss_cls.
5.4 Relation to the Paper

The implementation is closely tied to the focal-loss concept from "Focal Loss for Dense Object Detection": the focusing parameter gamma and the balancing factor alpha tune how much the loss attends to easy versus hard samples, which effectively mitigates class imbalance in object detection. The custom accuracy and activation hooks additionally let the class adapt to detection heads with non-standard classifier layouts.

5.5 A Worked Example

Suppose a three-class problem with predictions pred = $[[0.7, 0.2, 0.1], [0.3, 0.4, 0.3]]$ (treated as logits), labels target = $[0, 1]$, num_classes = 3, focusing parameter gamma = 2, balancing factor alpha = 0.25, reduction 'mean', and loss_weight = 1.0.

  1. Instantiate the FocalCustomLoss class:

    • loss_fn = FocalCustomLoss(num_classes=3, gamma=2, alpha=0.25)
  2. Call forward to compute the loss:

    • loss = loss_fn(pred, target)
    • Inside forward, the labels are first one-hot encoded via target = F.one_hot(target, num_classes=num_classes + 1), giving $[[1, 0, 0, 0], [0, 1, 0, 0]]$; keeping the first three columns then gives $[[1, 0, 0], [0, 1, 0]]$.
    • The loss function is set to py_sigmoid_focal_loss.
    • calculate_loss_func returns an unreduced loss of shape (2, 3). Suppose, purely for illustration, that the entries for the two samples sum to 0.1 and 0.2 respectively.
    • With reduction='mean' (and no avg_factor), the final loss is the mean over all $2 \times 3 = 6$ entries, i.e. $(0.1 + 0.2) / 6 = 0.05$.

This example shows how FocalCustomLoss turns raw predictions and integer labels into a focal loss and reduces it according to the given parameters; a runnable sketch follows below.
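
The same example as actual code (a sketch; the import path assumes an MMDetection version recent enough to ship FocalCustomLoss):

import torch
from mmdet.models.losses.focal_loss import FocalCustomLoss

loss_fn = FocalCustomLoss(num_classes=3, gamma=2.0, alpha=0.25)

pred = torch.tensor([[0.7, 0.2, 0.1],
                     [0.3, 0.4, 0.3]])   # logits, shape (2, 3)
target = torch.tensor([0, 1])            # integer class indices

loss = loss_fn(pred, target)             # scalar: mean over all 6 entries
print(loss)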

III. Conclusion

mmdet/models/losses/focal_loss.py offers several ways to compute Focal Loss: a pure PyTorch implementation, a CUDA-accelerated implementation, and customizable loss classes. These give users flexible options for loss computation in object detection, and the class-based wrappers can be extended further for task-specific needs. Focal Loss itself helps resolve the class-imbalance problem in object detection and improves model performance and generalization.
