14.Focal Loss

FocalLoss是一种为了解决目标检测中正负样本极度不平衡问题的损失函数,通过调整易分类样本的权重,让模型在训练时更加关注难分类样本。它基于二分类交叉熵,并引入了类别权重系数α和难度调节因子γ,有效地缓解了一阶段目标检测器的性能瓶颈。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


1.基础介绍

论文:Focal Loss for Dense Object Detection

这篇论文是HeKaiMing团队,2017年08月提交的,聚焦于探讨一阶段目标检测器与二阶段目标检测器的性能差异,**作者发现一阶目标检测的方法准确率不如二阶段目标检测器的主要原因是正负anchor训练样本的极度不平衡,一阶检测头给出的上万个候选框中大部分不包含物体,正负数据比列可达 1 : 100 ∗ ∗ 1:100** 1:100

什么意思呢?负样本数量太大,占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向更倾向于发现负样本,以前的算法使用OHEM算法来应对这种情况,Online Hard Example Mining,OHEM 每个batch中只取损失值最高的n个样本参与训练,OHEM算法虽然增加了错分类样本的权重,但是OHEM算法忽略了容易分类的样本。

Focal Loss 可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。

2.Focal Loss

2.1常规二分类交叉熵

L o s s b c e = − [ y l o g   p + ( 1 − y ) l o g   ( 1 − p ) ] Loss_{bce} = -[ylog\,p + (1-y)log\, (1-p)] Lossbce=[ylogp+(1y)log(1p)]

y = { 0 , 1 } y=\{0,1\} y={01}是标签, p p p是为正样本的概率

也可以写成:

p t = { p , y = 1 1 − p , o t h e r w i s e p_t=\left\{\begin{matrix} p,y=1\\ 1-p,otherwise \end{matrix}\right. pt={p,y=11p,otherwise

2.2 Focal Loss解读

上面普通的Binary Cross Entropy,BCE等同的对待每个类别的样本,无法处理类别不平衡的情况。

在这里插入图片描述

Focal Loss主要是引入了类别权重系数和简单/困难学习样本调节因子实现了对二分类类别不平衡样本数据的学习.

先写出Focal Loss的公式:

F L ( p t ) = − α t ( 1 − p t ) γ l o g p t FL(p_t) = -\alpha_t(1-p_t)^\gamma logp_t FL(pt)=αt(1pt)γlogpt

α t \alpha_t αt就是类别权重系数,其可以取类别数据频率的反比,以减小数据比较多的类别的损失所占的比重,使网络能从样本少的数据中学习。

α t = { α , y = 1 1 − α , o t h e r w i s e \alpha_t=\left\{\begin{matrix} \alpha,y=1\\ 1-\alpha,otherwise \end{matrix}\right. αt={α,y=11α,otherwise

在论文中模型RetinaNet使用Focal Loss设置的 α = 0.25 \alpha=0.25 α=0.25,强制损失权重比例到 1 : 3 1:3 1:3相对样本 1 : 1000 1:1000 1:1000已改善很多

( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ用来简单困难样本的学习,当 p t p_t pt接近于0时,说明这个正样本很难学习 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ的值接近于1,保持损失值使模型学习,当 p t p_t pt接近于1时,说明样本容易学习,减少损失。 论文中实验验证给出的 γ = 2 \gamma=2 γ=2

如上图,当 γ = 0 \gamma=0 γ=0时,Focal Loss就退化成了普通的BCE损失。

在这里插入图片描述

3.源码实现

import numpy as np
def focal_loss(preds: np.array, target: np.array):
     alpha = 0.25
     gamma = 2
     alpha_t = alpha * target + (1 - alpha) * (1 - target)
     p_t = (1 - preds) * target + preds * (1 - target)

     focal_loss = -alpha_t*p_t.pow(gamma) * [target*np.log(preds)+(1-target) * log(1 - preds)]
     return focal_loss

欢迎访问个人网络日志🌹🌹知行空间🌹🌹


参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值