引言
本文旨在对pytorch中常用于分类问题的损失函数BinaryCrossEntropy(), CrossEntropy()用法进行一个简要的介绍。常见的文章主要是对这些损失函数的原理进行了数学推导,而本文主要介绍了其输入输出的shape和格式要求,作为一个工具存在。
损失函数
本文涉及到的损失函数有BCELoss()、BCEWithLogitsLoss()、NLLLOSS()、CrossEntropyLoss(),前两者是二分类问题常用的损失函数,后两者是多分类问题常用的损失函数。列出格式表如下:
| 输入格式 | label的dtype | 是否为独热向量 | 网络输出是否需要激活 | |
|---|---|---|---|---|
| BCELoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 是 |
| BCEWithLogitsLoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 否 |
| NLLLOSS() | (pred:[N,C],label:[N,]) | torch.int64 | 否 | 是 |
| CrossEntropyLoss() | (pred:[N,C],label:[N,])或 (pred:[N,C],label:[N,C]) | torch.int64/torch.float32,torch.float64 | 否 | 否 |
其中CrossEntropyLoss之所以会有label为[N,C]形状却不并不为onehot向量,这是因为这里的label描述的是一个样本属于多个类别的情况,可以认为是属于每一种类别的可能性,也可以认为是软化的onehot向量。
计算方式
(默认在batch上采用平均):
BCELoss()
loss=−1N∑iN[yi⋅log(pi)+(1−yi)⋅log(1−pi)]loss=-\frac{1}{N}\sum_i^{N}[y_i\cdot log(p_i)+ (1-y_i)\cdot log(1-p_i)]loss=−N1∑iN[yi⋅log(pi)+(1−yi)⋅lo

本文详细介绍PyTorch中BinaryCrossEntropy()和CrossEntropyLoss()在二/多分类问题中的使用,包括输入shape要求、激活处理以及计算公式。帮助理解其在实际项目中的正确应用。
最低0.47元/天 解锁文章

被折叠的 条评论
为什么被折叠?



