关于多标签分类任务的损失函数和评价指标的一点理解

多标签分类详解
本文深入探讨了多标签分类任务中的损失函数BCEWithLogitsLoss与CrossEntropy的区别,以及它们在不同场景下的适用性。此外,还介绍了多标签任务中的评价指标,并重点解释了samplesavg指标的意义。

关于多标签分类任务的损失函数和评价指标的一点理解

之前有接触到多标签分类任务,但是主要关注点都放在模型结构中,最近关于多标签分类任务进行了一个讨论,发现其中有些细节不是太清楚,经过查阅资料逐渐理解,现在此记录。

多标签分类任务损失函数

在二分类、多分类任务中通常使用交叉熵损失函数,即Pytorch中的CrossEntorpy,但是在多标签分类任务中使用的是BCEWithLogitsLoss函数。

BCEWithLogitsLoss与CrossEntorpy的不同之处在于计算样本所属类别概率值时使用的计算函数不同:
1)CrossEntorpy使用softmax函数,即将模型输出值作为softmax函数的输入,进而计算样本属于每个类别的概率,softmax计算得到的类别概率值加和为1。
2)BCEWithLogitsLoss使用sigmoid函数,将模型输出值作为sigmoid函数的输入,计算得到的多个类别概率值加和不一定为1。

共同点是计算概率值后都继续计算预测概率值和真实标签之间的交叉熵作为最终的损失函数值。

为什么在多标签任务中使用BCEWithLogitsLoss(sigmoid)函数呢?个人理解如下:
1)二分类/多分类任务是在两个/多个类别中取出一个类别,并且各个类别之间是互斥的,因此要保证多个类别的概率值加和为1(在类别概率值加和为1的情况下,一个类别概率值增加时必然有其他类别概率值减小,体现了各个类别之间的互斥),并且最终取出概率值最大的类别。
2)多分类任务是在多个类别中取出一个或多个类别,各个类别之间不互斥,因此无需保证各类别概率加和为1,只需要计算样本属于每一个类别的概率,如果样本属于某一类别的概率高于阈值则代表样本属于该类别(此时类别概率值加和不一定为1),例如样本A经过BCEWithLogitsLoss函数计算后得到属于类别1、类别2、类别3和类别4的概率值分别为[0.6,0.7,0.3,0.4],阈值为0.5,则样本A同时属于类别1和类别2。

使用方法如下:


                
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值