【Pytorch基础】torch.nn.BCEWithLogitsLoss样本不均衡的处理

本文探讨了正负样本不均衡问题,并介绍了如何利用PyTorch中的BCEWithLogitsLoss来调整不同类别样本的损失权重,以解决样本不平衡问题。通过设置pos_weight参数,可以有效地对正样本的损失进行加权。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  遇到了正负样本不均衡的问题,正样本数目是负样本的5倍,这样会导致FP率较高。尝试将正样本的loss权重增高,看BCEWithLogitsLoss的源码。

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)
 
Args:
    weight (Tensor, optional): a manual rescaling weight given to the loss
        of each batch element. If given, has to be a Tensor of size `nbatch`.
    size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
        the losses are averaged over each loss element in the batch. Note that for
        some losses, there are multiple elements per sample. If the field :attr:`size_average`
        is set to ``False``, the losses are instead summed for each minibatch. Ignored
        when reduce is ``False``. Default: ``True``
    reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
        losses are averaged or summed over observations for each minibatch depending
        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
        batch element instead and ignores :attr:`size_average`. Default: ``True``
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the sum of the output will be divided by the number of
        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
    pos_weight (Tensor, optional): a weight of positive examples.
            Must be a vector with length equal to the number of classes.

  对其中的参数pos_weight的使用存在疑惑,BCEloss里的例子pos_weight = torch.ones([64]) # All weights are equal to 1,不懂为什么会有64个class,因为BCEloss是针对二分类问题的loss,后经过检索,得知还有多标签分类。
在这里插入图片描述
  多标签分类就是多个标签,每个标签有两个label(0和1),这类任务同样可以使用BCEloss。
在这里插入图片描述
  比如我们有正负两类样本,正样本数量为100个,负样本为400个,我们想要对正负样本的loss进行加权处理,将正样本的loss权重放大4倍,通过这样的方式缓解样本不均衡问题。

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
 
# pos_weight (Tensor, optional): a weight of positive examples.
#            Must be a vector with length equal to the number of classes.

  pos_weight里是一个tensor列表,需要和标签个数相同,比如我们现在是二分类,只需要将正样本loss的权重写上即可。如果是多标签分类,有64个标签,则

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)

转载来源

[1]转载来源

### 解析 `torch.nn.BCEWithLogitsLoss` 计算出的损失值为 NaN 的原因 当遇到 `torch.nn.BCEWithLogitsLoss` 输出的损失值为 NaN 时,通常是因为输入数据存在问题或模型训练过程中出现了数值不稳定的情况。具体来说: - 输入张量中的某些元素可能导致内部计算溢出或下溢,从而产生无穷大或零除错误。 - 如果标签和预测之间存在不匹配的数据类型或维度差异,也可能引发此类问题。 为了防止这种情况发生并解决问题,可以采取以下措施[^1]: #### 数据预处理与验证 确保输入到损失函数的数据已经过适当预处理,并且满足二分类交叉熵的要求: - 验证所有样本的目标变量只包含 {0, 1} 或者 {-1, 1} 这样的二元类别标记; - 对于预测概率分布,应保证其范围位于 (0, 1) 内;如果使用的是原始得分,则无需额外操作因为 `BCEWithLogitsLoss` 自动应用 Sigmoid 函数转换。 ```python import torch from torch import nn # 假设这是你的模型输出以及对应的ground truth labels predictions = ... # shape: [batch_size] labels = ... # shape: [batch_size] # 将标签转化为浮点数类型 labels = labels.float() # 创建损失实例对象 criterion = nn.BCEWithLogitsLoss() ``` #### 设置合理的超参数配置 调整优化器的学习率和其他相关设置来提高稳定性: - 较低的学习速率有助于减少权重更新幅度,进而降低梯度爆炸的风险; - 使用带有动量项或其他正则化机制(如 L2 正则)的方法可以帮助平滑收敛过程。 #### 添加数值稳定性的技巧 通过引入一些技术手段增强数值稳健性: - 当前版本 PyTorch 已内置了对极端情况下的保护逻辑,在大多数情况下不需要特别干预; - 不过仍建议定期检查是否有异常大的激活值出现,并考虑采用截断策略将其限制在一个安全范围内。 #### 调试工具的应用 利用调试工具定位潜在的问题源码位置: - 启用自动求导引擎中的检测功能 (`torch.autograd.set_detect_anomaly(True)` ) 可帮助发现反向传播期间产生的任何非法运算; - 结合可视化库绘制误差曲线图以便直观观察是否存在突变现象。 最后附上一段完整的代码片段用于测试上述方法的有效性: ```python def train_model(model, dataloader, optimizer, criterion=nn.BCEWithLogitsLoss()): model.train() running_loss = [] for inputs, targets in dataloader: outputs = model(inputs) # Ensure correct dtype and device placement targets = targets.to(outputs.device).float().unsqueeze(-1) loss = criterion(outputs, targets) if not torch.isnan(loss): optimizer.zero_grad() loss.backward() optimizer.step() running_loss.append(loss.item()) else: print('Encountered NaN during training.') break avg_train_loss = sum(running_loss)/len(running_loss) return avg_train_loss ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值