torch.nn.CrossEntropyLoss踩坑记录

本文通过实例详细解析了PyTorch中LogSoftmax、NLLLoss和CrossEntropyLoss的使用及相互关系。作者首先展示了如何手动实现LogSoftmax函数,并通过对比验证了其正确性。接着,介绍了NLLLoss的计算过程,同样通过手动计算与库函数比较,证实了NLLLoss是对每个样本真实类别概率取log后的平均值。最后,指出CrossEntropyLoss等价于LogSoftmax和NLLLoss的组合。作者在实践中遇到的问题是训练模型后所有样本被预测为同一类,这可能与模型结构、优化器设置或数据预处理有关。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.nn.LogSoftmax

model = nn.Linear(5, 2)
x = torch.randn(5, 5)
y = model(x)
y
Out[81]: 
tensor([[ 0.1358, -0.7312],
        [-0.1503, -0.2752],
        [-0.3823, -0.2795],
        [-1.2358, -0.0971],
        [-0.9683,  0.6525]], grad_fn=<AddmmBackward>)

# 用torch得到的结果
nn.LogSoftmax()(y)
Out[82]: 
tensor([[-0.3508, -1.2178],
        [-0.6326, -0.7575],
        [-0.7458, -0.6431],
        [-1.4164, -0.2778],
        [-1.8012, -0.1804]], grad_fn=<LogSoftmaxBackward>)

# 下面我们自己手动实现LogSoftmax
y_exp = y.exp()
y_exp
Out[85]: 
tensor([[1.1455, 0.4814],
        [0.8604, 0.7594],
        [0.6823, 0.7562],
        [0.2906, 0.9074],
        [0.3797, 1.9203]], grad_fn=<ExpBackward>)
lsm = torch.log(y_exp / y_exp.sum(-1).unsqueeze(1))
lsm
Out[95]: 
tensor([[-0.3508, -1.2178],
        [-0.6326, -0.7575],
        [-0.7458, -0.6431],
        [-1.4164, -0.2778],
        [-1.8012, -0.1804]], grad_fn=<LogBackward>)
# 可以发现,得到的结果和上面一样

由上可以验证公式 L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) LogSoftmax(x_i)=log(\frac{exp(x_i)}{\sum_jexp(x_j)}) LogSoftmax(xi)=log(jexp(xj)exp(xi))

torch.nn.NLLLoss

# 接上面的代码

# 先定义类别标签
target = torch.randint(0, 2, (5,))
target
Out[110]: tensor([0, 1, 1, 0, 1])

# 用torch封装的函数得到的结果
nlll = torch.nn.NLLLoss()(lsm, target)
nlll
Out[112]: tensor(0.6697, grad_fn=<NllLossBackward>)

# 下面我们自己手动实现NLLLoss(Negative Log Likely Loss)
sum = 0
for i in range(5):
    ...:     sum += -lsm[i, target[i]]
nlll = sum / target.size(0)
nlll
Out[124]: tensor(0.6697, grad_fn=<DivBackward0>)

以上可以验证,torch.nn.NLLLoss是把每个样本对应的真实类别的概率取log之后取平均。

torch.nn.CrossEntropyLoss

torch.nn.CrossEntropyLoss()(y, target)
Out[130]: tensor(0.6697, grad_fn=<NllLossBackward>)
# 与之前得到的loss一致

以上可以验证,nn.CrossEntropyLoss = nn.LogSoftmax + nn.NLLLoss


想说点什么呢

就是想说,网络得到的输出最后一层不需要加softmax就可以直接放进CrossEntropyLoss计算loss。
这两天写了个baseline,结果训练之后的分类结果总是奇迹般的得到一个类别的输出(就是一个10分类,所有的样本都预测为第0类)。
花了两天的时间,得到了这篇博客。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值