使用Pytorch框架进行深度学习任务,特别是分类任务时,经常会用到如下:
import torch.nn as nn
criterion = nn.CrossEntropyLoss().cuda()
loss = criterion(output, target)
即使用torch.nn.CrossEntropyLoss()作为损失函数。
那nn.CrossEntropyLoss()内部到底是啥??
nn.CrossEntropyLoss()是torch.nn中包装好的一个类,对应torch.nn.functional中的cross_entropy。
此外,nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合(将两者结合到一个类中)。
nn.logSoftmax()
定义如下:
从公式看,其实就是先softmax在log。
nn.NLLLoss()
定义如下:
此loss期望的target是类别的索引 (0 to N-1, where N &#