Common Loss Functions in Image Enhancement Tasks

Table of Contents

🧠 Core Component Analysis

1. reduce_loss(loss, reduction) - Basic Reduction Function

2. weight_reduce_loss(loss, weight=None, reduction='mean') - Weighted Reduction Function

3. weighted_loss(loss_func) - Decorator (Core Interface)

💡 Code Design Summary and Takeaways

🚀 Applications in Image Restoration Tasks

Key Points and Design Ideas

A Simple Example

An Example Where weight Has Size 1 in Dimension 1

📊 Core Concepts at a Glance

🧮 Calculation Steps in Detail

💡 Why Multiply by loss.size(1)?

⚖️ Other reduction Modes


```python
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper
```

This code implements a very practical toolkit for weighted loss functions. Its main purpose is to simplify the implementation of custom losses, especially when they need to support element-wise weights and multiple reduction modes. This is common in tasks such as image restoration, where different regions of an image are assigned different levels of importance.
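Before diving into the individual functions, here is a minimal sketch of that use case, assuming the weighted_loss decorator from the listing above is in scope (the tensor shapes and the mask are made up for illustration):

```python
import torch

@weighted_loss
def l1_loss(pred, target):
    # element-wise loss only; weighting and reduction are added by the decorator
    return (pred - target).abs()

pred = torch.rand(1, 3, 8, 8)    # N, C, H, W
target = torch.rand(1, 3, 8, 8)

# Single-channel spatial mask (weight.size(1) == 1), broadcast over channels:
# only the top half of the image contributes to the loss.
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4, :] = 1.0

loss = l1_loss(pred, target, weight=mask)  # mean over the masked region only
```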

Below, we analyze the functionality and design ideas of the code part by part.

🧠 Core Component Analysis

1. reduce_loss(loss, reduction) - Basic Reduction Function

This function reduces an element-wise loss tensor according to the specified reduction mode.

  • Inputs:
    • loss (Tensor): an element-wise loss tensor of arbitrary shape.
    • reduction (str): the reduction mode; one of 'none', 'mean', or 'sum'.
  • Behavior:
    • With reduction='none', the original loss tensor is returned without any aggregation.
    • With reduction='mean', the mean over all elements is returned.
    • With reduction='sum', the sum over all elements is returned.
  • Purpose: this is the basic reduction primitive on which the more complex weighted reduction below is built (a minimal usage sketch follows this list).
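A minimal usage sketch, assuming the reduce_loss definition from the listing above is in scope:

```python
import torch

loss = torch.tensor([[1.0, 2.0],
                     [3.0, 4.0]])

print(reduce_loss(loss, 'none'))  # tensor([[1., 2.], [3., 4.]]) -- unchanged
print(reduce_loss(loss, 'mean'))  # tensor(2.5000) -- mean of all 4 elements
print(reduce_loss(loss, 'sum'))   # tensor(10.)    -- sum of all 4 elements
```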
2. weight_reduce_loss(loss, weight=None, reduction='mean') - Weighted Reduction Function

This is the core function: it implements the logic of applying element-wise weights and then reducing the loss.

  • Inputs:
    • loss (Tensor): an element-wise loss tensor.
    • weight (Tensor, optional): an element-wise weight tensor. It must have the same number of dimensions as loss, and its channel dimension must satisfy weight.size(1) == 1 or weight.size(1) == loss.size(1). Default: None.
    • reduction (str): 'none', 'mean', or 'sum', matching PyTorch's built-in losses. Default: 'mean'.
  • Behavior:
    • If weight is given, the loss is first multiplied element-wise by it (broadcasting across channels when weight.size(1) == 1).
    • If weight is None or reduction='sum', the function simply falls back to reduce_loss.
    • If reduction='mean' and a weight is given, the loss sum is divided by the weight sum instead of the element count, yielding a mean over the weighted region. When weight.size(1) == 1, the weight sum is additionally multiplied by loss.size(1), because each weight entry is broadcast over loss.size(1) channel elements. The worked example below illustrates both cases.
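A small worked example under 'mean' reduction, assuming weight_reduce_loss from the listing above is in scope:

```python
import torch

loss = torch.ones(1, 2, 2, 2)  # N=1, C=2, H=2, W=2 -> 8 elements, each 1.0

# Full-channel weight: weight.size(1) == loss.size(1) == 2
w_full = torch.zeros(1, 2, 2, 2)
w_full[..., 0] = 1.0           # keep half of the positions -> weight sum = 4
# denominator = w_full.sum() = 4, numerator = (loss * w_full).sum() = 4
print(weight_reduce_loss(loss, w_full, 'mean'))   # tensor(1.)

# Broadcast weight: weight.size(1) == 1, shared across both channels
w_bcast = torch.zeros(1, 1, 2, 2)
w_bcast[..., 0] = 1.0          # weight sum = 2, but it covers 2 channels
# denominator = w_bcast.sum() * loss.size(1) = 2 * 2 = 4, numerator = 4
print(weight_reduce_loss(loss, w_bcast, 'mean'))  # tensor(1.)
```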
### Common Loss Functions for Image Classification

In deep learning, and especially in image classification tasks, a variety of loss functions are widely used to optimize model performance. These losses affect not only how effectively training proceeds but also the quality of the final model.

#### Softmax Loss

Softmax Loss is a common multi-class classification loss: a Softmax function is applied at the last layer of the network to turn logits into a probability distribution, and the cross entropy against the true label is then computed as the loss value [^4]. This mechanism effectively measures the gap between predicted and actual labels and guides parameter updates to shrink it.

```python
import torch.nn.functional as F

def softmax_loss(output, target):
    # cross_entropy combines log-softmax and negative log-likelihood
    return F.cross_entropy(output, target)
```

#### Weighted Softmax Loss

To cope with imbalanced class sizes in a dataset, Weighted Softmax Loss adjusts the per-class weights so that minority classes gain importance, which improves overall generalization [^2] (see the sketch at the end of this section).

#### Focal Loss

Focal Loss, proposed by Tsung-Yi Lin, Kaiming He, and their colleagues, mainly targets the class-imbalance problem in object detection. By adding a dynamic weighting factor to the standard cross entropy, it reduces the dominance of abundant easy examples and strengthens learning on hard negatives [^3].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2., alpha=.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        # element-wise binary cross entropy, left unreduced
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # estimated probability of the true class
        # (1 - pt) ** gamma down-weights easy, well-classified examples
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)
```

#### L-Softmax Loss / Hinge Loss / Exponential Loss / Logistic Loss

Beyond the typical methods above, variants such as L-Softmax Loss, Hinge Loss, Exponential Loss, and Logistic Loss have also been widely studied and applied in improved forms for specific scenarios.
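As a concrete illustration of the class-weighting idea behind Weighted Softmax Loss, here is a minimal sketch using the built-in weight argument of PyTorch's F.cross_entropy; the class weights and batch values are made-up numbers:

```python
import torch
import torch.nn.functional as F

# Hypothetical 3-class problem where class 2 is rare; giving it a larger
# weight makes its samples contribute more to the averaged loss.
class_weights = torch.tensor([0.5, 1.0, 3.0])

logits = torch.randn(4, 3)             # batch of 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])   # ground-truth labels

loss = F.cross_entropy(logits, targets, weight=class_weights)
print(loss)
```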