tensorflow分类的loss函数_图解Focal Loss以及Tensorflow实现(二分类、多分类)

本文详细介绍了Focal Loss,一种用于解决分类任务中类别不平衡问题的损失函数。通过图解和实例展示了Focal Loss如何衰减简单样本的损失,并提供了二分类和多分类的Tensorflow实现代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

总体上讲,Focal Loss是一个缓解分类问题中类别不平衡、难易样本不均衡的损失函数。首先看一下论文中的这张图:

解释:

横轴是ground truth类别对应的概率(经过sigmoid/softmax处理过的logits),纵轴是对应的loss值;

蓝色的线(gamma=0),就是原始交叉熵损失函数,可以明显看出ground truth的概率越大,loss越小,符合常识;

除了蓝色的线,其他几个都是Focal Loss的线,其实原始交叉熵损失函数是Focal Loss的特殊版本(gamma=0)

其他几个Focal Loss线都在蓝色下边,可以看出Focal Loss的作用就是【衰减】;

从图中可以看出,ground truth的概率越大(即容易分类的简单样本),衰减越厉害,也就是大大降低了简单样本的loss;

从图中可以看出,ground truth的概率越小(即不易分类的困难样本),也是有衰减的,但是衰减的程度比较小;

下边是我自己模拟的一组数据,一组固定的logits=[0+epsilon, 0.1, 0.2, ..., 0.9, 1.0-epsilon],然后假设ground truth分别是0、1、2、...、9、10的时候,gamma=0、0.5、1、2、...、8、16对应的loss。

例如第3行第1列的2.75表示,ground truth是类别2,即对应的logits是0.2,gamma=0的时候,loss=2.75(gamma=0,就是原始的多分类交叉熵)。

根据上表可以得到下边的图:

从上图可以看出,随着gamma增大,整体loss都下降了,但是logits相对越高(这个例子中最大logits=1),下降的倍数越大。从上表的最后一列也可以看出来,g

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值