CrossEntropyLoss 等价于 softmax+log+NLLLoss
LogSoftmax等价于softmax+log
可用于文本分类、序列标注等计算损失
使用方法:
# 首先定义该类
loss = torch.nn.CrossEntropyLoss()
#然后传参进去
loss(input, target)
input维度为N*C,是网络生成的值,N为batch_size,C为类别数;
target维度为N,是标注值,非one-hot类型的值;
input = torch.randn(4,3)
target = torch.tensor([0,1,1,2]) #必须为Long类型,是类别的序号
cross_entropy_loss = nn.CrossEntropyLoss()
loss = cross_entropy_loss(input, target)
# 对于序列标注来说,需要reshape一下
input = torch.randn(2,4,3) # 2为batch_size, 4为seq_length,3为类别数
input = input.view(-1,3) # 一共8个token
target = torch.tensor([[0,1,1,2], [2,3,1,0]])
target = target.view(-1)
loss = cross_entropy_loss(input, target) # reduction='mean',默认为mean;
参考文章:pytorch中的CrossEntropyLoss
PyTorch详解NLLLoss和CrossEntropyLoss
对PyTorch中F.cross_entropy()函数的理解
PyTorch学习笔记——softmax和log_softmax的区别、CrossEntropyLoss() 与 NLLLoss() 的区别、log似然代价函数
Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的关系与区别详解
本文探讨了PyTorch中CrossEntropyLoss的使用,它等同于LogSoftmax加上负对数似然损失(NLLLoss)。通常应用于文本分类和序列标注任务。CrossEntropyLoss的输入是N*C维度的网络输出,N为批次大小,C为类别数,而目标是N维度的非one-hot标签。文章引用了几篇参考资料,详细解释了CrossEntropyLoss与其他相关函数的区别。
1269





