[深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

本文深入探讨了PyTorch中NLLLoss与CrossEntropyLoss两种损失函数的使用场景与区别。通过实例说明了NLLLoss适用于logsoftmax处理后的数据,而CrossEntropyLoss直接用于原始输出数据。对比了两种函数在预测与目标值匹配过程中的梯度求导与更新机制。

 

nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据

  1. target = torch.tensor([1, 2])

  2. entropy_out = F.cross_entropy(data, target)

  3. nll_out = F.nll_loss(log_soft, target)

  4. 注意这里的代价函数的特点,无论是采用cross_entropy 还是nll_loss,预测结果与目标值都看起来是两路数据。但是并不影响梯度求导更新过程,最终调整成预测与target目标函数越来越接近

https://zengwenqi.blog.youkuaiyun.com/article/details/96282788

### 交叉熵损失函数的计算原理 在深度学习中,`nn.CrossEntropyLoss()` 是 PyTorch 提供的一个常用损失函数,用于多分类任务。其计算过程结合了 `nn.LogSoftmax()` `nn.NLLLoss()` 两个函数。输入张量的形状通常为 `(minibatch, C)`,其中 `C` 表示类别数。损失函数的计算公式如下: $$ \text{Loss}(x, y) = -\log\left(\frac{\exp(x_y)}{\sum_{c=1}^{C} \exp(x_c)}\right) $$ 其中 $ x $ 是模型的输出(logits),$ y $ 是标签值。标签值并不直接参与计算,而是作为索引,用于选择对应类别的输出值。若设置了权重参数 `weight`,则损失值会根据权重进行调整 [^2]。 ### 使用代码示例 以下是一个简单的代码示例,展示了 `nn.CrossEntropyLoss()` 的使用方法: ```python import torch import torch.nn as nn # 假设输入张量的形状为 (minibatch, C),C 是类别数 logits = torch.tensor([[2.0, 1.0, 0.1], [1.0, 2.0, 0.1]], dtype=torch.float32) # 标签张量的形状为 (minibatch),每个元素的值是类别索引 labels = torch.tensor([0, 1], dtype=torch.long) # 初始化交叉熵损失函数 criterion = nn.CrossEntropyLoss() # 计算损失 loss = criterion(logits, labels) # 输出损失值 print("交叉熵损失:", loss.item()) ``` ### 忽略特定标签值 在某些情况下,可能需要忽略某些标签值,例如无效标签或填充标签。`nn.CrossEntropyLoss()` 提供了 `ignore_index` 参数,用于指定需要忽略的标签值。在计算损失时,这些标签值不会被纳入计算。以下是一个示例: ```python import torch import torch.nn.functional as F # 模型预测输出 pred = torch.tensor([[0.9, 0.1], [0.8, 0.2]], dtype=torch.float32) # 标签张量,其中 -1 表示不参与损失计算 label = torch.tensor([[1], [-1]], dtype=torch.long) # 计算交叉熵损失,并忽略标签值 -1 loss = F.cross_entropy(pred.view(-1, 2), label.view(-1), ignore_index=-1) # 输出损失值 print("忽略特定标签值的交叉熵损失:", loss.item()) ``` 在上述示例中,`ignore_index=-1` 表示在计算损失时,标签值为 `-1` 的样本不会被纳入计算。同时,计算平均损失时,只会考虑实际参与计算的样本数 [^3]。 ### 总结 `nn.CrossEntropyLoss()` 是一个非常高效的损失函数,适用于多分类任务。其计算过程结合了 `nn.LogSoftmax()` `nn.NLLLoss()`,能够直接处理模型的输出标签值。通过设置 `weight` `ignore_index` 参数,可以灵活地调整损失计算方式,适应不同的任务需求 。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值