Pytorch中交叉熵损失函数分析

本文详细介绍PyTorch中BinaryCrossEntropy()和CrossEntropyLoss()在二/多分类问题中的使用,包括输入shape要求、激活处理以及计算公式。帮助理解其在实际项目中的正确应用。

引言

本文旨在对pytorch中常用于分类问题的损失函数BinaryCrossEntropy(), CrossEntropy()用法进行一个简要的介绍。常见的文章主要是对这些损失函数的原理进行了数学推导,而本文主要介绍了其输入输出的shape和格式要求,作为一个工具存在。

损失函数

 本文涉及到的损失函数有BCELoss()、BCEWithLogitsLoss()、NLLLOSS()、CrossEntropyLoss(),前两者是二分类问题常用的损失函数,后两者是多分类问题常用的损失函数。列出格式表如下:

输入格式 label的dtype 是否为独热向量 网络输出是否需要激活
BCELoss() (pred:[*],label:[*],二者相同即可) torch.float32
BCEWithLogitsLoss() (pred:[*],label:[*],二者相同即可) torch.float32
NLLLOSS() (pred:[N,C],label:[N,]) torch.int64
CrossEntropyLoss() (pred:[N,C],label:[N,])或 (pred:[N,C],label:[N,C]) torch.int64/torch.float32,torch.float64

 其中CrossEntropyLoss之所以会有label为[N,C]形状却不并不为onehot向量,这是因为这里的label描述的是一个样本属于多个类别的情况,可以认为是属于每一种类别的可能性,也可以认为是软化的onehot向量。

计算方式

(默认在batch上采用平均):

BCELoss()

loss=−1N∑iN[yi⋅log(pi)+(1−yi)⋅log(1−pi)]loss=-\frac{1}{N}\sum_i^{N}[y_i\cdot log(p_i)+ (1-y_i)\cdot log(1-p_i)]loss=N1iN[yilog(pi)+(1yi)lo

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值