欢迎访问个人网络日志🌹🌹知行空间🌹🌹
1.基础介绍
这篇论文是HeKaiMing
团队,2017年08月提交的,聚焦于探讨一阶段目标检测器与二阶段目标检测器的性能差异,**作者发现一阶目标检测的方法准确率不如二阶段目标检测器的主要原因是正负anchor
训练样本的极度不平衡,一阶检测头给出的上万个候选框中大部分不包含物体,正负数据比列可达
1
:
100
∗
∗
1:100**
1:100∗∗。
什么意思呢?负样本数量太大,占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向更倾向于发现负样本,以前的算法使用OHEM算法来应对这种情况,Online Hard Example Mining,OHEM 每个batch中只取损失值最高的n个样本参与训练,OHEM算法虽然增加了错分类样本的权重,但是OHEM算法忽略了容易分类的样本。
Focal Loss
可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。
2.Focal Loss
2.1常规二分类交叉熵
L o s s b c e = − [ y l o g p + ( 1 − y ) l o g ( 1 − p ) ] Loss_{bce} = -[ylog\,p + (1-y)log\, (1-p)] Lossbce=−[ylogp+(1−y)log(1−p)]
y = { 0 , 1 } y=\{0,1\} y={0,1}是标签, p p p是为正样本的概率
也可以写成:
p t = { p , y = 1 1 − p , o t h e r w i s e p_t=\left\{\begin{matrix} p,y=1\\ 1-p,otherwise \end{matrix}\right. pt={p,y=11−p,otherwise
2.2 Focal Loss解读
上面普通的Binary Cross Entropy,BCE等同的对待每个类别的样本,无法处理类别不平衡的情况。
Focal Loss
主要是引入了类别权重系数和简单/困难学习样本调节因子实现了对二分类类别不平衡样本数据的学习.
先写出Focal Loss
的公式:
F L ( p t ) = − α t ( 1 − p t ) γ l o g p t FL(p_t) = -\alpha_t(1-p_t)^\gamma logp_t FL(pt)=−αt(1−pt)γlogpt
α t \alpha_t αt就是类别权重系数,其可以取类别数据频率的反比,以减小数据比较多的类别的损失所占的比重,使网络能从样本少的数据中学习。
α t = { α , y = 1 1 − α , o t h e r w i s e \alpha_t=\left\{\begin{matrix} \alpha,y=1\\ 1-\alpha,otherwise \end{matrix}\right. αt={α,y=11−α,otherwise
在论文中模型RetinaNet
使用Focal Loss
设置的
α
=
0.25
\alpha=0.25
α=0.25,强制损失权重比例到
1
:
3
1:3
1:3相对样本
1
:
1000
1:1000
1:1000已改善很多
( 1 − p t ) γ (1-p_t)^\gamma (1−pt)γ用来简单困难样本的学习,当 p t p_t pt接近于0时,说明这个正样本很难学习 ( 1 − p t ) γ (1-p_t)^\gamma (1−pt)γ的值接近于1,保持损失值使模型学习,当 p t p_t pt接近于1时,说明样本容易学习,减少损失。 论文中实验验证给出的 γ = 2 \gamma=2 γ=2。
如上图,当
γ
=
0
\gamma=0
γ=0时,Focal Loss
就退化成了普通的BCE损失。
3.源码实现
import numpy as np
def focal_loss(preds: np.array, target: np.array):
alpha = 0.25
gamma = 2
alpha_t = alpha * target + (1 - alpha) * (1 - target)
p_t = (1 - preds) * target + preds * (1 - target)
focal_loss = -alpha_t*p_t.pow(gamma) * [target*np.log(preds)+(1-target) * log(1 - preds)]
return focal_loss
欢迎访问个人网络日志🌹🌹知行空间🌹🌹