目录
《深入剖析 MMDetection 中的 Focal Loss 实现》
在目标检测领域,损失函数的设计对于模型的性能起着至关重要的作用。MMDetection 中的focal_loss.py
文件实现了 Focal Loss,这是一种为解决类别不平衡问题而提出的强大损失函数。
原文詳細解析見:Loss:Focal Loss for Dense Object Detection
一、文件概述
mmdet/models/losses/focal_loss.py
主要包含了一系列与 Focal Loss 相关的函数和类,旨在为目标检测任务提供有效的损失计算方法。
二、函数解析
1. py_sigmoid_focal_loss
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
# Actually, pt here denotes (1 - pt) in the Focal Loss paper
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
# Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
1.1 方法功能
这个方法实现了 PyTorch 版本的焦点损失(Focal Loss),用于解决类别不平衡问题,尤其在目标检测等任务中,对难分类的样本给予更大的权重,提高模型对困难样本的学习能力。
1.2 参数解释
pred
:预测值张量,形状为(N, C)
,其中N
是样本数量,C
是类别数量。预测值在经过sigmoid
激活函数之前的值,表示每个样本属于每个类别的概率。target
:真实标签张量,与pred
的类型相同。其形状可以是(N,)
表示单个类别标签,也可以是(N, C)
表示独热编码(one-hot)形式的标签。weight
:样本权重张量,可选参数。形状可以是(N,)
或(N, C)
,用于对不同的样本或样本的不同类别设置不同的权重。gamma
:焦点损失中的调制因子参数,控制难分类样本的权重增加程度。默认值为2.0
。alpha
:平衡因子,用于平衡正负样本的权重。默认值为0.25
。reduction
:损失的归约方式,可以是'mean'
(求均值)、'sum'
(求和)或'none'
(不进行归约)。默认值为'mean'
。avg_factor
:用于平均损失的因子,当reduction
为'mean'
时使用。
1.3 代码逐步解析
-
计算
sigmoid
后的预测值和标签pred_sigmoid = pred.sigmoid()
:对预测值进行sigmoid
激活,将其转换为概率值。target = target.type_as(pred)
:确保标签张量的数据类型与预测值张量一致。
-
计算调制因子
pt
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
:这里的pt
实际上在焦点损失论文中表示(1 - pt)
,其中pt
是模型对当前样本预测正确的概率。如果样本属于正类,pt
就是预测为正类的概率;如果样本属于负类,pt
就是预测为负类的概率。
-
计算焦点权重
focal_weight
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
:根据平衡因子alpha
、标签target
和调制因子pt
计算焦点权重。对于正样本,焦点权重会根据pt
和alpha
进行调整,使得难分类的正样本(即pt
较小的样本)获得更大的权重;对于负样本同理。
-
计算二元交叉熵损失并乘以焦点权重
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight
:首先计算二元交叉熵损失,这里使用的是未经过sigmoid
激活的预测值pred
和真实标签target
,因为F.binary_cross_entropy_with_logits
函数内部会自动对预测值进行sigmoid
激活并计算损失。然后将损失乘以焦点权重,得到最终的焦点损失。
-
应用样本权重
- 如果
weight
不为None
:- 如果
weight
的形状与loss
的形状不一致,进行形状调整:- 如果
weight
的第一个维度大小与loss
的第一个维度大小相同,说明weight
可能是形状为(num_priors,)
的样本权重,没有类别维度。此时将其扩展为(num_priors, 1)
的形状。 - 如果
weight
的元素数量与loss
的元素数量相同,说明weight
可能是形状为(num_priors x num_class,)
的权重,而loss
是形状为(num_priors, num_class)
的损失。在这种情况下,将weight
调整为与loss
相同的形状。
- 如果
- 确保
weight
的维度与loss
的维度一致。
- 如果
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
:使用weight_reduce_loss
函数根据给定的权重weight
、归约方式reduction
和平均因子avg_factor
对损失进行归约处理,得到最终的损失值。
- 如果
-
返回损失值
- 返回计算得到的焦点损失值。
1.4 結合原文講解:
(1)论文关系概述
py_sigmoid_focal_loss
函数是基于论文《Focal Loss for Dense Object Detection》中提出的焦点损失(Focal Loss)实现的。焦点损失主要用于解决目标检测中正负样本类别不平衡的问题,通过对容易分类的样本降低权重,对难分类的样本增加权重,使得模型更加关注难分类的样本,从而提高检测精度。
(2)对应公式
-
Sigmoid 函数:
pred_sigmoid = pred.sigmoid()
,将模型的预测结果pred
通过 Sigmoid 函数转换为概率值。Sigmoid 函数公式为 σ ( x ) = 1 1 + e − x \sigma(x)=\frac{1}{1 + e^{-x}} σ(x)=1+e−x1。
-
** modulating factor(调制因子)计算**:
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
,这里的pt
实际上在论文中表示((1 - p_t)$,其中(p_t)是样本属于真实类别的概率。focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
,这是焦点损失中的调制因子部分。其中 α \alpha α是平衡因子,用于平衡正负样本的权重; γ \gamma γ是聚焦参数,用于控制调制的程度。
-
Focal Loss 公式:
loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight
,这里结合了二元交叉熵损失(Binary Cross Entropy)和调制因子,得到最终的焦点损失。二元交叉熵损失公式为 L B C E = − y log ( p ) − ( 1 − y ) log ( 1 − p ) L_{BCE} = -y\log(p)-(1 - y)\log(1 - p) LBCE=−ylog(p)−(1−y)log(1−p),其中 y y y是真实标签, p p p是预测概率。焦点损失公式为 F L ( p t ) = − ( 1 − p t ) γ log ( p t ) FL(p_t)=-(1-p_t)^{\gamma}\log(p_t) FL(pt)=−(1−pt)γlog(pt),实际使用的是 α \alpha α平衡后的 Focal Loss: F L ( p t ) = − α t ( 1 − p t ) γ log ( p t ) FL(p_t)=-\alpha_t(1-p_t)^{\gamma}\log(p_t) FL(pt)=−αt(1−pt)γlog(pt)。
(2)举例解析过程
假设我们有一个二分类问题,预测结果pred
为
[
0.6
,
0.3
]
[0.6, 0.3]
[0.6,0.3],真实标签target
为
[
1
,
0
]
[1, 0]
[1,0],
γ
=
2
\gamma = 2
γ=2,
α
=
0.25
\alpha = 0.25
α=0.25。
-
首先计算 Sigmoid 后的预测概率:
pred_sigmoid = [0.6224593314718246, 0.5744425168245771]
。
-
计算 p t p_t pt和 ( 1 − p t ) (1 - p_t) (1−pt):
- 对于第一个样本(正样本), p t = p r e d s i g m o i d [ 0 ] = 0.6224593314718246 p_t = pred_sigmoid[0]=0.6224593314718246 pt=predsigmoid[0]=0.6224593314718246,则 ( 1 − p t ) = 0.3775406685281754 (1 - p_t)=0.3775406685281754 (1−pt)=0.3775406685281754。
- 对于第二个样本(负样本), p t = p r e d s i g m o i d [ 1 ] = 0.5744425168245771 p_t = pred_sigmoid[1]=0.5744425168245771 pt=predsigmoid[1]=0.5744425168245771,则 ( 1 − p t ) = 0.4255574831754229 (1 - p_t)=0.4255574831754229 (1−pt)=0.4255574831754229。
-
计算调制因子:
- 对于第一个样本(正样本), p t = ( 1 − p r e d s i g m o i d [ 0 ] ) ∗ t a r g e t [ 0 ] + p r e d s i g m o i d [ 0 ] ∗ ( 1 − t a r g e t [ 0 ] ) = ( 1 − 0.6224593314718246 ) ∗ 1 + 0.6224593314718246 ∗ ( 1 − 1 ) = 0.3775406685281754 pt = (1 - pred_sigmoid[0]) * target[0] + pred_sigmoid[0] * (1 - target[0])=(1 - 0.6224593314718246) * 1 + 0.6224593314718246 * (1 - 1)=0.3775406685281754 pt=(1−predsigmoid[0])∗target[0]+predsigmoid[0]∗(1−target[0])=(1−0.6224593314718246)∗1+0.6224593314718246∗(1−1)=0.3775406685281754。
- 对于第二个样本(负样本), p t = ( 1 − p r e d s i g m o i d [ 1 ] ) ∗ t a r g e t [ 1 ] + p r e d s i g m o i d [ 1 ] ∗ ( 1 − t a r g e t [ 1 ] ) = ( 1 − 0.5744425168245771 ) ∗ 0 + 0.5744425168245771 ∗ ( 1 − 0 ) = 0.4255574831754229 pt = (1 - pred_sigmoid[1]) * target[1] + pred_sigmoid[1] * (1 - target[1])=(1 - 0.5744425168245771) * 0 + 0.5744425168245771 * (1 - 0)=0.4255574831754229 pt=(1−predsigmoid[1])∗target[1]+predsigmoid[1]∗(1−target[1])=(1−0.5744425168245771)∗0+0.5744425168245771∗(1−0)=0.4255574831754229。
- 对于第一个样本(正样本), f o c a l w e i g h t = ( a l p h a ∗ t a r g e t [ 0 ] + ( 1 − a l p h a ) ∗ ( 1 − t a r g e t [ 0 ] ) ) ∗ p t . p o w ( g a m m a ) = ( 0.25 ∗ 1 + ( 1 − 0.25 ) ∗ ( 1 − 1 ) ) ∗ 0.377540668528175 4 2 = 0.03522044134486913 focal_weight = (alpha * target[0] + (1 - alpha) * (1 - target[0])) * pt.pow(gamma)=(0.25 * 1 + (1 - 0.25) * (1 - 1)) * 0.3775406685281754^{2}=0.03522044134486913 focalweight=(alpha∗target[0]+(1−alpha)∗(1−target[0]))∗pt.pow(gamma)=(0.25∗1+(1−0.25)∗(1−1))∗0.37754066852817542=0.03522044134486913。
- 对于第二个样本(负样本), f o c a l w e i g h t = ( a l p h a ∗ t a r g e t [ 1 ] + ( 1 − a l p h a ) ∗ ( 1 − t a r g e t [ 1 ] ) ) ∗ p t . p o w ( g a m m a ) = ( 0.25 ∗ 0 + ( 1 − 0.25 ) ∗ ( 1 − 0 ) ) ∗ 0.425557483175422 9 2 = 0.14933689641952515 focal_weight = (alpha * target[1] + (1 - alpha) * (1 - target[1])) * pt.pow(gamma)=(0.25 * 0 + (1 - 0.25) * (1 - 0)) * 0.4255574831754229^{2}=0.14933689641952515 focalweight=(alpha∗target[1]+(1−alpha)∗(1−target[1]))∗pt.pow(gamma)=(0.25∗0+(1−0.25)∗(1−0))∗0.42555748317542292=0.14933689641952515。
-
计算二元交叉熵损失:
- 对于第一个样本(正样本), L B C E = − t a r g e t [ 0 ] log ( p r e d s i g m o i d [ 0 ] ) − ( 1 − t a r g e t [ 0 ] ) log ( 1 − p r e d s i g m o i d [ 0 ] ) = − 1 log ( 0.6224593314718246 ) − ( 1 − 1 ) log ( 1 − 0.6224593314718246 ) = − 0.4941490331089446 L_{BCE}=-target[0]\log(pred_sigmoid[0])-(1 - target[0])\log(1 - pred_sigmoid[0])=-1\log(0.6224593314718246)-(1 - 1)\log(1 - 0.6224593314718246)=-0.4941490331089446 LBCE=−target[0]log(predsigmoid[0])−(1−target[0])log(1−predsigmoid[0])=−1log(0.6224593314718246)−(1−1)log(1−0.6224593314718246)=−0.4941490331089446。
- 对于第二个样本(负样本), L B C E = − t a r g e t [ 1 ] log ( p r e d s i g m o i d [ 1 ] ) − ( 1 − t a r g e t [ 1 ] ) log ( 1 − p r e d s i g m o i d [ 1 ] ) = − 0 log ( 0.5744425168245771 ) − ( 1 − 0 ) log ( 1 − 0.5744425168245771 ) = − 0.5596313035264794 L_{BCE}=-target[1]\log(pred_sigmoid[1])-(1 - target[1])\log(1 - pred_sigmoid[1])=-0\log(0.5744425168245771)-(1 - 0)\log(1 - 0.5744425168245771)=-0.5596313035264794 LBCE=−target[1]log(predsigmoid[1])−(1−target[1])log(1−predsigmoid[1])=−0log(0.5744425168245771)−(1−0)log(1−0.5744425168245771)=−0.5596313035264794。
-
计算焦点损失:
- 对于第一个样本(正样本), l o s s = L B C E ∗ f o c a l w e i g h t = − 0.4941490331089446 ∗ 0.03522044134486913 = − 0.017411441774379015 loss = L_{BCE} * focal_weight=-0.4941490331089446 * 0.03522044134486913=-0.017411441774379015 loss=LBCE∗focalweight=−0.4941490331089446∗0.03522044134486913=−0.017411441774379015。
- 对于第二个样本(负样本), l o s s = L B C E ∗ f o c a l w e i g h t = − 0.5596313035264794 ∗ 0.14933689641952515 = − 0.08354377734184265 loss = L_{BCE} * focal_weight=-0.5596313035264794 * 0.14933689641952515=-0.08354377734184265 loss=LBCE∗focalweight=−0.5596313035264794∗0.14933689641952515=−0.08354377734184265。
(4)与代码对应
pred_sigmoid = pred.sigmoid()
:对应计算 Sigmoid 后的预测概率。target = target.type_as(pred)
:确保真实标签的数据类型与预测结果一致。pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
:计算 p t p_t pt和 ( 1 − p t ) (1 - p_t) (1−pt)的中间值。focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
:计算调制因子。loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight
:结合二元交叉熵损失和调制因子计算焦点损失。if weight is not None:
及后续代码:如果有样本权重,则根据权重对损失进行调整。loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
:根据指定的 reduction 方法(如mean
、sum
等)和平均因子对损失进行进一步处理。
2. py_focal_loss_with_prob
def py_focal_loss_with_prob(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
Args:
pred (torch.Tensor): The prediction probability with shape (N, C),
C is the number of classes.
target (torch.Tensor): The learning label of the prediction.
The target shape support (N,C) or (N,), (N,C) means one-hot form.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
if pred.dim() != target.dim():
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
target = target.type_as(pred)
pt = (1 - pred) * target + pred * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
2.1 函数与论文关系概述
py_focal_loss_with_prob
函数同样是基于论文《Focal Loss for Dense Object Detection》中提出的焦点损失(Focal Loss)实现的。与py_sigmoid_focal_loss
不同的是,这个函数接受概率作为输入。
2.2 对应公式
-
** modulating factor(调制因子)计算**:
pt = (1 - pred) * target + pred * (1 - target)
,这里的pt
与论文中的含义一致,只不过这里直接使用输入的概率pred
进行计算。focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
,同样是焦点损失中的调制因子部分,其中(\alpha)是平衡因子,(\gamma)是聚焦参数。
-
Focal Loss 公式:
loss = F.binary_cross_entropy(pred, target, reduction='none') * focal_weight
,这里结合了二元交叉熵损失(Binary Cross Entropy)和调制因子,得到最终的焦点损失。二元交叉熵损失公式为(L_{BCE} = -y\log§-(1 - y)\log(1 - p)),其中(y)是真实标签,§是预测概率。焦点损失公式为(FL(p_t)=-\alpha_t(1-p_t)^{\gamma}\log(p_t))。
2.3 举例解析过程
假设我们有一个三分类问题,预测概率pred
为([0.7, 0.2, 0.1]),真实标签target
为([1, 0, 0])(以 one-hot 形式表示),(\gamma = 2),(\alpha = 0.25)。
-
计算(p_t)和((1 - p_t)):
- 对于第一个类别(真实类别),(p_t = pred[0]=0.7),则((1 - p_t)=0.3)。
- 对于第二个类别,(p_t = pred[1]=0.2),则((1 - p_t)=0.8)。
- 对于第三个类别,(p_t = pred[2]=0.1),则((1 - p_t)=0.9)。
-
计算调制因子:
- 对于第一个类别(真实类别),(pt = (1 - pred[0]) * target[0] + pred[0] * (1 - target[0])=(1 - 0.7) * 1 + 0.7 * (1 - 1)=0.3)。
- 对于第二个类别,(pt = (1 - pred[1]) * target[1] + pred[1] * (1 - target[1])=(1 - 0.2) * 0 + 0.2 * (1 - 0)=0.2)。
- 对于第三个类别,(pt = (1 - pred[2]) * target[2] + pred[2] * (1 - target[2])=(1 - 0.1) * 0 + 0.1 * (1 - 0)=0.1)。
- 对于第一个类别(真实类别),(focal_weight = (alpha * target[0] + (1 - alpha) * (1 - target[0])) * pt.pow(gamma)=(0.25 * 1 + (1 - 0.25) * (1 - 1)) * 0.3^{2}=0.0225)。
- 对于第二个类别,(focal_weight = (alpha * target[1] + (1 - alpha) * (1 - target[1])) * pt.pow(gamma)=(0.25 * 0 + (1 - 0.25) * (1 - 0)) * 0.2^{2}=0.03)。
- 对于第三个类别,(focal_weight = (alpha * target[2] + (1 - alpha) * (1 - target[2])) * pt.pow(gamma)=(0.25 * 0 + (1 - 0.25) * (1 - 0)) * 0.1^{2}=0.0075)。
-
计算二元交叉熵损失:
- 对于第一个类别(真实类别),(L_{BCE}=-target[0]\log(pred[0])-(1 - target[0])\log(1 - pred[0])=-1\log(0.7)-(1 - 1)\log(1 - 0.7)=-0.35667494393873245)。
- 对于第二个类别,(L_{BCE}=-target[1]\log(pred[1])-(1 - target[1])\log(1 - pred[1])=-0\log(0.2)-(1 - 0)\log(1 - 0.2)=-0.22314355131420976)。
- 对于第三个类别,(L_{BCE}=-target[2]\log(pred[2])-(1 - target[2])\log(1 - pred[2])=-0\log(0.1)-(1 - 0)\log(1 - 0.1)=-0.10536051565782628)。
-
计算焦点损失:
- 对于第一个类别(真实类别),(loss = L_{BCE} * focal_weight=-0.35667494393873245 * 0.0225=-0.008025186238621759)。
- 对于第二个类别,(loss = L_{BCE} * focal_weight=-0.22314355131420976 * 0.03=-0.006694306539426293)。
- 对于第三个类别,(loss = L_{BCE} * focal_weight=-0.10536051565782628 * 0.0075=-0.0007902038674336971)。
2.4 与代码对应
if pred.dim()!= target.dim():
:如果预测概率和真实标签的维度不一致,说明真实标签不是 one-hot 形式,需要进行转换。num_classes = pred.size(1)
:获取预测的类别数。target = F.one_hot(target, num_classes=num_classes + 1)
:将真实标签转换为 one-hot 形式,这里的num_classes + 1
是为了处理可能的背景类。target = target[:, :num_classes]
:只取前num_classes
个维度,去除可能的背景类。
target = target.type_as(pred)
:确保真实标签的数据类型与预测结果一致。pt = (1 - pred) * target + pred * (1 - target)
:计算(p_t)和((1 - p_t))的中间值。focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
:计算调制因子。loss = F.binary_cross_entropy(pred, target, reduction='none') * focal_weight
:结合二元交叉熵损失和调制因子计算焦点损失。if weight is not None:
及后续代码:如果有样本权重,则根据权重对损失进行调整。loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
:根据指定的 reduction 方法(如mean
、sum
等)和平均因子对损失进行进一步处理。
3. sigmoid_focal_loss
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A wrapper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
alpha, None, 'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
3.1 函数与论文关系概述
sigmoid_focal_loss
函数是对焦点损失(Focal Loss)的一种实现,与论文《Focal Loss for Dense Object Detection》紧密相关。该函数旨在解决目标检测中正负样本类别不平衡的问题,通过引入调制因子,使模型更加关注难分类的样本。
3.2 对应公式
-
** modulating factor(调制因子)计算**:
- 与前面两个函数类似,这里也涉及到调制因子的计算。公式中的(\gamma)是聚焦参数,用于控制调制的程度;(\alpha)是平衡因子,用于平衡正负样本的权重。
- 调制因子的计算公式为(focal_weight=(alpha * target+(1 - alpha)*(1 - target))*pt.pow(gamma)),其中(pt)的计算方式与前面的函数类似。
-
Focal Loss 公式:
- 该函数最终的损失计算也是结合了二元交叉熵损失和调制因子。焦点损失公式为(FL(p_t)=-\alpha_t(1-p_t)^{\gamma}\log(p_t))。
3.3 举例解析过程
假设我们有一个二分类问题,预测结果pred
为([0.6, 0.3]),真实标签target
为([1, 0]),(\gamma = 2),(\alpha = 0.25)。
-
首先调用
_sigmoid_focal_loss
函数计算初始损失:loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, 'none')
。这里假设_sigmoid_focal_loss
函数内部实现了与前面函数类似的计算逻辑,得到初始的焦点损失值,假设这个值为([0.1, 0.05])。
-
处理权重:
- 如果有样本权重
weight
,则需要根据权重的形状进行调整。 - 如果
weight.size(0) == loss.size(0)
,说明权重的形状是((num_priors,)),没有类别维度,此时将权重调整为((num_priors, 1))的形状,即weight = weight.view(-1, 1)
。 - 如果权重的形状与损失的形状不同,但元素数量相等,说明可能是扁平的形状,需要根据损失的形状进行调整,即
weight = weight.view(loss.size(0), -1)
。 - 确保权重的维度与损失的维度一致,即
assert weight.ndim == loss.ndim
。
- 如果有样本权重
-
计算最终损失:
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
。这里根据指定的reduction
方法(如mean
、sum
等)和平均因子avg_factor
对损失进行进一步处理,得到最终的损失值。
3.3 与代码对应
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, 'none')
:调用底层的 Cuda 版本的焦点损失计算函数,得到初始的损失值。if weight is not None:
及后续代码:处理样本权重,如果有权重,则根据权重的形状进行调整,并确保权重的维度与损失的维度一致。loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
:根据指定的reduction
方法和平均因子对损失进行进一步处理,得到最终的损失值。
4. FocalLoss
类的相关方法
4.1 源碼
@MODELS.register_module()
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0,
activated=False):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
activated (bool, optional): Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
"""
super(FocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.activated = activated
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
The target shape support (N,C) or (N,), (N,C) means
one-hot form.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
if self.activated:
calculate_loss_func = py_focal_loss_with_prob
else:
if pred.dim() == target.dim():
# this means that target is already in One-Hot form.
calculate_loss_func = py_sigmoid_focal_loss
elif torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
weight,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
4.2 解析
- 方法概述
forward
方法是FocalLoss
类的核心方法,用于计算焦点损失(Focal Loss)。它接受预测值pred
、目标值target
以及一些可选参数,如权重weight
、平均因子avg_factor
和损失缩减方式覆盖参数reduction_override
,并返回计算得到的损失张量。
- 参数解释
1). pred
:预测值张量,通常是模型对某个样本的预测结果,可能是一个概率值或未经过激活的原始输出。
2). target
:目标值张量,代表预测值对应的真实标签,可以是 one-hot 编码形式或其他形式。
3). weight
(可选):损失权重张量,用于对每个预测的损失进行加权。默认值为None
。
4). avg_factor
(可选):平均因子,用于平均损失。默认值为None
。
5). reduction_override
(可选):用于覆盖默认损失缩减方式的参数。可以是"none"
、"mean"
或"sum"
。默认值为None
。
- 计算过程
1). 检查损失缩减方式的合法性:
assert reduction_override in (None, 'none', 'mean', 'sum')
:确保reduction_override
参数的值在合法的取值范围内。
2). 确定损失缩减方式:
reduction = reduction_override if reduction_override else self.reduction
:如果reduction_override
不为None
,则使用该参数指定的缩减方式;否则,使用在类初始化时设置的缩减方式self.reduction
。
3). 根据条件选择损失计算函数:
- 如果
self.use_sigmoid
为True
(在类初始化时已断言只能是这种情况):- 如果
self.activated
为True
,表示输入已经是激活后的概率形式,此时将py_focal_loss_with_prob
赋值给calculate_loss_func
。 - 如果输入未激活且预测值和目标值维度相同,说明目标值已经是 one-hot 编码形式,将
py_sigmoid_focal_loss
赋值给calculate_loss_func
。 - 如果输入未激活且有 GPU 可用且预测值在 GPU 上,将
sigmoid_focal_loss
赋值给calculate_loss_func
。 - 如果上述条件都不满足,将目标值转换为 one-hot 编码形式后,将
py_sigmoid_focal_loss
赋值给calculate_loss_func
。
- 如果
- 如果
self.use_sigmoid
为False
,则抛出NotImplementedError
,因为目前只支持使用sigmoid
的焦点损失。
4). 计算损失:
loss_cls = self.loss_weight * calculate_loss_func(pred, target, weight, gamma=self.gamma, alpha=self.alpha, reduction=reduction, avg_factor=avg_factor)
:使用选择的损失计算函数计算损失,并乘以在类初始化时设置的损失权重self.loss_weight
,得到最终的损失值。
5). 返回损失值:
return loss_cls
:返回计算得到的损失张量。
5. FocalCustomLoss
类的相关方法
@MODELS.register_module()
class FocalCustomLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
num_classes=-1,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0,
activated=False):
"""`Focal Loss for V3Det <https://arxiv.org/abs/1708.02002>`_
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
num_classes (int): Number of classes to classify.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
activated (bool, optional): Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
"""
super(FocalCustomLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.num_classes = num_classes
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.activated = activated
assert self.num_classes != -1
# custom output channels of the classifier
self.custom_cls_channels = True
# custom activation of cls_score
self.custom_activation = True
# custom accuracy of the classsifier
self.custom_accuracy = True
def get_cls_channels(self, num_classes):
assert num_classes == self.num_classes
return num_classes
def get_activation(self, cls_score):
fine_cls_score = cls_score[:, :self.num_classes]
score_classes = fine_cls_score.sigmoid()
return score_classes
def get_accuracy(self, cls_score, labels):
fine_cls_score = cls_score[:, :self.num_classes]
pos_inds = labels < self.num_classes
acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
acc = dict()
acc['acc_classes'] = acc_classes
return acc
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
weight,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
5.1 整体概述
FocalCustomLoss
是一个基于 PyTorch 的神经网络模块,它实现了一种自定义的焦点损失(Focal Loss),用于特定的目标检测任务。这个类继承自nn.Module
,可以作为神经网络中的一个损失层来使用。它的设计主要基于论文《Focal Loss for Dense Object Detection》,旨在解决目标检测中的类别不平衡问题。
5.2 构造函数(__init__
方法)
-
参数解释:
use_sigmoid=True
:表示是否使用 Sigmoid 函数对预测结果进行处理。默认值为True
,目前只支持使用 Sigmoid 焦点损失。num_classes=-1
:要分类的类别数量。这个参数必须在实例化该类时指定一个非负整数。gamma=2.0
:用于计算调制因子的聚焦参数。较高的值会使模型更加关注难分类的样本。alpha=0.25
:平衡因子,用于平衡正负样本的权重。reduction='mean'
:损失的归约方法,可以是'none'
、'mean'
或'sum'
。默认值为'mean'
,表示对损失进行平均。loss_weight=1.0
:损失的权重。可以用来调整该损失在总损失中的贡献。activated=False
:如果为True
,表示输入已经经过激活,可以被视为概率;如果为False
,则输入应被视为对数几率。
-
断言与自定义属性设置:
assert use_sigmoid is True
:确保只支持使用 Sigmoid 焦点损失。assert self.num_classes!= -1
:确保在实例化该类时指定了类别数量。- 设置了一些自定义属性:
self.custom_cls_channels = True
:表示分类器的输出通道是自定义的。self.custom_activation = True
:表示分类器的激活是自定义的。self.custom_accuracy = True
:表示分类器的准确率计算是自定义的。
5.3 方法解析
-
get_cls_channels
方法:- 作用:获取分类器的输出通道数。
- 实现:接收一个参数
num_classes
,并断言该参数与实例化时指定的类别数量相等,然后返回该类别数量。
-
get_activation
方法:- 作用:对分类器的输出进行激活。
- 实现:接收一个参数
cls_score
,表示分类器的输出分数。首先取出前self.num_classes
个类别对应的分数fine_cls_score
,然后对其进行 Sigmoid 激活,得到概率值score_classes
并返回。
-
get_accuracy
方法:- 作用:计算分类器的准确率。
- 实现:接收两个参数
cls_score
和labels
,分别表示分类器的输出分数和真实标签。首先取出前self.num_classes
个类别对应的分数fine_cls_score
,然后找出真实标签小于类别数量的样本索引pos_inds
,对于这些样本计算准确率acc_classes
,并将其封装在一个字典中返回。
-
forward
方法:- 作用:前向传播函数,计算焦点损失。
- 参数解释:
pred
:模型的预测结果,是一个 PyTorch 张量。target
:真实标签,也是一个 PyTorch 张量。weight
:每个样本的损失权重,可选参数。avg_factor
:用于平均损失的因子,可选参数。reduction_override
:用于覆盖默认的损失归约方法的参数,可选。
- 实现过程:
- 首先检查
reduction_override
参数是否合法。 - 如果
self.use_sigmoid
为True
(目前只支持这种情况):- 计算预测结果的类别数量
num_classes
。 - 将真实标签转换为 one-hot 编码形式
target
。 - 设置计算损失的函数为
py_sigmoid_focal_loss
。 - 计算焦点损失
loss_cls
,将损失权重、预测结果、真实标签、聚焦参数、平衡因子、归约方法和平均因子传递给calculate_loss_func
。
- 计算预测结果的类别数量
- 如果
self.use_sigmoid
为False
,则抛出NotImplementedError
,因为目前只支持 Sigmoid 焦点损失。 - 最后返回计算得到的损失
loss_cls
。
- 首先检查
5.4 结合论文分析
这个类的实现与论文《Focal Loss for Dense Object Detection》中的焦点损失概念紧密相关。通过设置聚焦参数gamma
和平衡因子alpha
,可以调整损失函数对易分类和难分类样本的关注程度,从而有效地解决目标检测中的类别不平衡问题。同时,通过自定义的准确率计算方法和激活函数,可以更好地适应特定的目标检测任务。
5.5 举例说明
假设我们有一个三分类问题,预测结果pred
为([[0.7, 0.2, 0.1], [0.3, 0.4, 0.3]]),真实标签target
为([0, 1]),类别数量num_classes = 3
,聚焦参数gamma = 2
,平衡因子alpha = 0.25
,损失归约方法为'mean'
,损失权重loss_weight = 1.0
。
-
实例化
FocalCustomLoss
类:loss_fn = FocalCustomLoss(num_classes=3, gamma=2, alpha=0.25)
。
-
调用
forward
方法计算损失:loss = loss_fn(pred, target)
。- 在
forward
方法中,首先将真实标签转换为 one-hot 编码形式,即target = F.one_hot(target, num_classes=num_classes + 1)
,得到([[1, 0, 0], [0, 1, 0]]),然后取前三个类别,得到([[1, 0, 0], [0, 1, 0]]`。 - 设置计算损失的函数为
py_sigmoid_focal_loss
。 - 调用
calculate_loss_func
计算焦点损失,假设得到的损失值为([0.1, 0.2])。 - 由于损失归约方法为
'mean'
,所以最终的损失为((0.1 + 0.2) / 2 = 0.15`。
通过这个例子,可以看到FocalCustomLoss
类如何根据输入的预测结果和真实标签计算焦点损失,并根据指定的参数进行归约。
三、结语
mmdet/models/losses/focal_loss.py
文件提供了多种实现 Focal Loss 的方式,包括纯 PyTorch 实现、Cuda 加速实现以及自定义的损失类。这些实现方式为目标检测任务中的损失计算提供了灵活的选择,用户可以根据自己的需求选择合适的方式来计算 Focal Loss。同时,通过自定义的损失类,用户还可以根据特定的任务需求进行进一步的扩展和定制。Focal Loss 的引入有助于解决目标检测中的类别不平衡问题,提高模型的性能和泛化能力。