pytorch中的nn.CrossEntropyLoss()损失

博客围绕使用Pytorch深度学习框架计算分类损失展开,介绍了结合nn.LogSoftmax()和nn.NLLLoss()的损失函数,说明了其输入格式、权重参数设置,还通过具体代码和实例分析了该函数的原理及损失计算过程。

在使用pytorch深度学习框架,计算分类损失时经常会遇到这么一个函数nn.CrossEntropyLoss(),该损失函数结合了nn.LogSoftmax()和nn.NLLLoss()两个函数。它在做分类(具体几类)训练的时候是非常有用的,如下我将对该函数的原理使用代码和实例进行分析。
    
首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:
在这里插入图片描述
损失函数中也有权重weight参数设置,若设置权重,则公式为:
在这里插入图片描述
注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象才为实际类别。

举个栗子,我们一共有5种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:

举个栗子,我们一共有5个类别,batch_size =2

import torch
import torch.nn as nn
import math

entroy=nn.CrossEntropyLoss()
input=torch.tensor([[ 0.0043, -0.0174, -0.0153,  
`nn.CrossEntropyLoss()` 是 PyTorch 中用于计算交叉熵损失函数,通常用于多分类问题,衡量模型预测的概率分布与真实标签之间的差异[^3]。 #### 含义 交叉熵损失是分类任务中常用的损失函数,它可以衡量模型预测的概率分布和真实标签的概率分布之间的差异。模型预测的概率分布越接近真实标签的概率分布,交叉熵损失就越小。 #### 使用方法 以下是一个简单的使用示例: ```python import torch import torch.nn as nn # 输入数据 logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # [batch_size, num_classes] target = torch.tensor([1, 2]) # 真实类别索引 # 定义损失函数 criterion = nn.CrossEntropyLoss() # 计算损失 loss = criterion(logits, target) print("CrossEntropyLoss:", loss.item()) ``` 在上述代码中,首先导入了必要的库,然后定义了输入数据 `logits` 和真实标签 `target`。接着,创建了 `nn.CrossEntropyLoss` 的实例 `criterion`,并使用该实例计算了损失。 #### 计算过程 从数学上看,`nn.CrossEntropyLoss` 可以分解为 `LogSoftmax` 和 `nn.NLLLoss`,以下是验证代码: ```python import torch import torch.nn as nn # 输入数据 logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # [batch_size, num_classes] target = torch.tensor([1, 2]) # 真实类别索引 # 方法 1:直接用 nn.CrossEntropyLoss ce_loss_fn = nn.CrossEntropyLoss() ce_loss = ce_loss_fn(logits, target) print("CrossEntropyLoss:", ce_loss.item()) # 方法 2:LogSoftmax + nn.NLLLoss log_softmax = nn.LogSoftmax(dim=1) nll_loss_fn = nn.NLLLoss() log_probs = log_softmax(logits) # 计算对数概率 nll_loss = nll_loss_fn(log_probs, target) print("LogSoftmax + NLLLoss:", nll_loss.item()) ``` 这个代码展示了两种计算交叉熵损失的方法,一种是直接使用 `nn.CrossEntropyLoss`,另一种是先使用 `LogSoftmax` 计算对数概率,再使用 `nn.NLLLoss` 计算损失,两种方法的结果是相同的[^2]。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值