关于pytorch交叉熵损失函数使用weight参数来平衡各个类别,即加权交叉熵损失函数

总结

torch.nn.CrossEntropyLoss(weight, reduction=“mean”)的计算方式和一般做法不一样,进行加权计算后,不是直接loss除以batch数量,而是batch中每个数据的标签对应的类别权重 把batch个权重加起来作为除数,被除数是loss

交叉熵损失函数原理

这里就不做介绍了,很多博客和视频已经讲的很清楚了。这里引用一下@b站up同济子豪兄之前的一次动态截图
具体可以搜索这个动态,关键词交叉熵

不带权重的交叉熵计算

这里已经也有很多介绍了,不带权重的情况下,torch.nn.CrossEntropyLoss(reduction=“mean”)的计算和公式里一样。
另外,虽然pytorch里面CrossEntropyLoss的target输入要求是标签或者Probabilities,但是target从标签(batch_size, 1(num_classes))转换成独热编码(torch.nn.functional.one_hot),输出结果是一样的。我感觉是转换成独热编码就相当于target是那个Probabilities了,具体计算过程没细看,结果和手搓的代码一样。

import torch.nn as nn
import torch.nn.functional as F

weight=torch.tensor([1,1,1,1]).float()  # 这里权重都是1,所以结果一样
loss_func_mean = nn.CrossEntropyLoss(weight, reduction="mean")
def ce_loss(y_pred, y_true, weight):

    # 计算 log_softmax,这是更稳定的方式
    log_probs = F.log_softmax(y_pred, dim=-1)

    # 计算加权损失
    loss = -(weight * y_true * log_probs).sum(dim=-1).mean()

    return loss
pre_data = torch.tensor([[0.8, 0.5, 0.2, 0.5],
                         [0.2, 0.9, 0.3, 0.2],
                         [0.4, 0.3, 0.7, 0.1],
                         [0.1, 0.2, 0.4, 0.8]], dtype=torch
PyTorch中的空间加权交叉熵损失函数是一种用于图像分割任务的损失函数,它结合了交叉熵损失和空间加权的思想。在图像分割任务中,我们希望同时考虑像素分类的准确性和像素位置的重要性,以更好地处理图像边缘等关键区域。 下面是一个简单的示例代码,展示了如何使用PyTorch实现空间加权交叉熵损失函数: ```python import torch import torch.nn as nn class SpatialWeightedCrossEntropyLoss(nn.Module): def __init__(self, weight=None, size_average=True): super(SpatialWeightedCrossEntropyLoss, self).__init__() self.weight = weight self.size_average = size_average def forward(self, input, target): # 计算交叉熵损失 log_softmax = nn.functional.log_softmax(input, dim=1) loss = nn.functional.nll_loss(log_softmax, target, weight=self.weight, reduction='none') # 计算空间加权损失 spatial_weight = torch.arange(0, input.size(2), dtype=torch.float32) / input.size(2) spatial_weight = spatial_weight.unsqueeze(0).unsqueeze(2).expand_as(loss).to(loss.device) weighted_loss = spatial_weight * loss # 计算平均损失 if self.size_average: return torch.mean(weighted_loss) else: return torch.sum(weighted_loss) ``` 在这个示例中,我们定义了一个名为SpatialWeightedCrossEntropyLoss的自定义损失函数类,继承自nn.Module。在forward方法中,我们首先计算了交叉熵损失,然后使用torch.arange函数生成了一个空间权重张量,该张量的大小与输入张量的大小相同。最后,将空间权重乘以交叉熵损失,得到最终的空间加权损失。 你可以根据你的具体需求,对这个示例代码进行修改和调整。希望对你有帮助!如果你有任何其他问题,请随时提问。
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值