crossentropy java_Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss

本文详细解析了Pytorch中的CrossEntropyLoss函数,包括其内部实现原理,如何整合nn.LogSoftmax()和nn.NLLLoss()。同时,通过实例展示了如何处理one-hot编码的target,以及CrossEntropyLoss与nn.LogSoftmax()+nn.NLLLoss()计算结果的等价性。文章还补充了torch.nn和torch.nn.functional中相关函数的对应关系。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用Pytorch框架进行深度学习任务,特别是分类任务时,经常会用到如下:

import torch.nn as nn

criterion = nn.CrossEntropyLoss().cuda()

loss = criterion(output, target)

即使用torch.nn.CrossEntropyLoss()作为损失函数。

那nn.CrossEntropyLoss()内部到底是啥??

nn.CrossEntropyLoss()是torch.nn中包装好的一个类,对应torch.nn.functional中的cross_entropy。

此外,nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合(将两者结合到一个类中)。

nn.logSoftmax()

定义如下:

7d8942e3e6b4ca38b83ef29daf85219c.png

从公式看,其实就是先softmax在log。

nn.NLLLoss()

定义如下:

52643a9a223cf5665da1304a0295fb35.png

此loss期望的target是类别的索引 (0 to N-1, where N &#

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值