1.简介
接下来就是CrossEntropyLoss了,在分类任务中经常用到。建议大家先看看上一篇NLLLoss的文章再来看这篇,会比较好理解。
2.CrossEntropyLoss
CrossEntropyLoss的计算公式如下:
其中是每个类别的权重,默认的全为1,
表示对应target那一类的概率。可以看到相比于NLLLoss而言,就是只是将输入进行了Softmax之后再log,所以说NLLLoss和CrossEntropyLoss也可以表示为以下关系:
3.思考
首先我们来看为什么这个叫CrossEntropyLoss。交叉熵(Cross Entropy)用来表达两个概率分布之间的相似性,熵越大则表示差别越大,损失值也就越大。
交叉熵的公式为:
其中表示对应真值类别,
表示对应类别概率。
当标签为0-1分布时,上述公式其实就与NLLLoss的计算公式相同,即只计算对应标签的概率作为loss。
通常情况下,网络模型的输出不是0-1分布,这是我们就需要使用Softmax将输出转换为概率分布,即累加和为1,Softmax之后的输出就对应了,加上log就对应了交叉熵的公式。
4.pytorch代码
以下代码为pytorch官方CrossEntropyLoss代码,可以看到里面有几个参数,我们大多数情况下使用默认参数设置就好。
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)
其中:
- weight表示每个类别的权重,当标签不平衡的时候可以使用来防止过拟合。
- size_average表示是否将样本的loss进行平均之后输出,默认为true。
- ignore_index表示忽略某一类别,不想训练某些类别时可用。
- reduce表示是否将输出进行压缩,默认为true。当它为false的时候就会无视size_average。
- reduction表示用怎么的方法进行reduce。可以设置为'none','mean','sum'。
- label_smoothing表示标签的平滑系数,类似于soft label,将0-1分布软化为概率分布
这里需要注意的是label_smoothing会导致标签不再是0-1分布,因此计算方法不再是基于NLLLoss的计算方式,而是基于交叉熵的公式计算。
import torch
import torch.nn as nn
a = torch.randn(3, 5)
b = torch.Tensor([0, 4, 1]).long()
criterion = nn.CrossEntropyLoss()
c = criterion(a, b)
print(c)
业务合作/学习交流+v:lizhiTechnology
如果想要了解更多损失函数相关知识,可以参考我的专栏和其他相关文章:
【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-优快云博客
【损失函数】(二) L2Loss原理 & pytorch代码解析_l2 loss-优快云博客
【损失函数】(三) NLLLoss原理 & pytorch代码解析_nll_loss-优快云博客
【损失函数】(四) CrossEntropyLoss原理 & pytorch代码解析_crossentropyloss 权重-优快云博客
【损失函数】(五) BCELoss原理 & pytorch代码解析_bce损失函数源码解析-优快云博客
如果想要了解更多深度学习相关知识,可以参考我的其他文章: