Focal loss 详细解读

针对one-stage目标检测中正负样本严重不平衡的问题,Focal Loss通过引入平衡因子α及聚焦参数γ,有效提升了模型对难以分类样本的学习能力,同时平衡了正负样本的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Focal loss所解决的问题是在object detection领域由正负样本不平衡,而导致one-stage detector方法的准确率不如two-stage detector。one-stage detector主要是在一张图像上生成成千上万的candidate locations,这些candidate location中包含所需要的检测目标只有一部分区域即正样本,大多数区域是与检测目标不相关的背景即负样本。这样就会造成candidate location中,正样本数远远小于负样本数。根据交叉熵函数:

Loss = -\log y{}' - (1-y)\log \left ( 1-y{}' \right )

Loss = \begin{cases} -log(y'),y=1\\ -log(1-y'),otherwise \end{cases}

从上面公式可以做出图像

y’表示分成正样本的概率,可以看出当y’越大时,loss越小。所以当真值为1时,y’越大说明样本容易分类为标签1,所以为简单样本,loss越小,对模型更新贡献较少。y’越小说明样本不容易分类成标签1,所以为复杂样本,loss会大,对模型更新显著。当真值为0时,y’越大说明模型容易把错样本分成正样本,loss=-log(1-y’),loss偏大,模型容易更新优化。y’越小说明模型容易正确地把样本正确地分为负样本,loss偏小,即对模型更新效果不显著。这样一来,在one-stage detector中会出现以下问题:

1.当数据集中简单正样本偏多,由于loss偏低则模型更新较为缓慢,且可能无法优化至最优。

2.在candidate location中,作为背景的负样本偏多,且多是容易分类,导致模型优化方向会偏向于容易出检测负样本,难检测出正样本,且训练过程中loss偏低模型迭代缓慢。

Focal loss的baseline  

在原有的Loss = \begin{cases} -log(y'),y=1\\ -log(1-y'),otherwise \end{cases}基础之上,设pt = \begin{cases} y',y=1\\ 1-y',otherwise \end{cases},则Loss = -log(p_t)

在这个loss前加一个平衡因子alpha

\alpha _t = \begin{cases} a,y=1\\ 1-a,otherwise \end{cases}

以此来控制正负本的loss权重,如a=0.4,则正样本贡献为0.4,负样本贡献为0.6。

有了这样的正负样本控制因子之后,并不能控制困难和简单样本的权重。

因此focal loss在此之上又增加了一个Focusing parameter,gamma

(1-p_t)^{}\gamma,此时loss函数公式为

Loss = \alpha_t \cdot \left (1-p_t \right )^{\gamma }\cdot \left( -\log \left ( p_t \right )\right )

我们可以看出,pt的值越大,则说明为简单样本,-\log \left ( p_t \right )越大 ,\left ( 1-p_t \right )^{\gamma }值越小。p_t值越小,则说明为困难样本,-\log \left ( p_t \right )越小,\left ( 1-p_t \right )^{\gamma }越大,则说明为困难样本。

通过focal loss,可以使得:

1.平衡因子alpha控制正负样本对总loss的共享权重,来平衡模型对正负样本的检测能力。

2.注意因子gamma将模型中难分的正样本的loss权重调高,loss贡献较多,使模型对难分的样本分类能力增强。

### Focal Loss 论文 PDF 原版下载 Focal Loss 的原始论文名为《Focal Loss for Dense Object Detection》,由 Lin et al. 提出。该论文的主要贡献在于引入了一种新的损失函数——Focal Loss,用于解决密集目标检测中的类别不平衡问题[^3]。 以下是获取原版论文的方法: 1. **官方链接**: 可通过 ArXiv 平台访问论文的公开版本。具体地址为: [https://arxiv.org/abs/1708.02002](https://arxiv.org/abs/1708.02002) 2. **GitHub 资源**: 如果需要更详细解读或相关实现代码,可以通过 GitHub 上的相关项目找到补充资料。例如,您提到的一个资源已经提供了关于 Focal Loss 的公式推导和解释[^1]。 3. **学术搜索引擎**: 使用 Google Scholar 或其他学术搜索引擎输入关键词 “Focal Loss for Dense Object Detection”,即可找到多种格式的下载选项。 #### 关于 Focal Loss 的核心概念 Focal Loss 是在标准交叉熵基础上改进的一种损失函数,其主要目的是降低容易分类样本的影响,使得模型能够更多关注难以分类的样本。公式定义如下: \[ FL(p_t) = -(1-p_t)^\gamma \log(p_t) \] 其中 \( p_t \) 表示预测的概率值,\( \gamma \) 是可调参数,通常取值大于零。 ```python import torch import torch.nn as nn class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, size_average=True): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) self.size_average = size_average def forward(self, input, target): if input.dim()>2: input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W input = input.transpose(1,2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1,1) logpt = nn.functional.log_softmax(input, dim=-1) logpt = logpt.gather(1,target) logpt = logpt.view(-1) pt = torch.exp(logpt) if self.alpha is not None: if self.alpha.type()!=input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0,target.data.view(-1)) logpt = logpt * at loss = -1 * (1-pt)**self.gamma * logpt if self.size_average: return loss.mean() else: return loss.sum() ``` 上述代码实现了 Focal Loss 的 PyTorch 版本,便于实际应用时参考。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值