Contents
1. reduce_loss(loss, reduction) - basic reduction function
2. weight_reduce_loss(loss, weight=None, reduction='mean') - weighted reduction function
3. weighted_loss(loss_func) - the decorator (core interface)
import functools

from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper
This code implements a practical toolkit for weighted loss functions. Its goal is to simplify writing custom losses that need to support element-wise weights and multiple reduction modes. This is common in tasks such as image restoration, where different regions of an image should contribute to the loss with different importance.
Below we walk through the functionality and design of each component.
🧠 Core Components Explained
1. reduce_loss(loss, reduction) - basic reduction function
This function reduces an element-wise loss tensor according to the specified reduction mode.
- Inputs:
  - loss (Tensor): an element-wise loss tensor of arbitrary shape.
  - reduction (str): the reduction mode, one of 'none', 'mean' and 'sum'.
- Behavior:
  - reduction='none': return the raw loss tensor without any aggregation.
  - reduction='mean': return the mean of all elements.
  - reduction='sum': return the sum of all elements.
- Purpose: this is the basic reduction primitive that the more complex weighted reduction below builds on (see the short example after this list).
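As a quick sanity check, this is how the three modes behave on a small tensor (a minimal sketch, assuming reduce_loss from the listing above is in scope; the values are illustrative):

import torch

loss = torch.tensor([[0.5, 1.5], [2.0, 4.0]])

print(reduce_loss(loss, 'none'))   # the tensor itself, unchanged
print(reduce_loss(loss, 'mean'))   # tensor(2.) - average of all 4 elements
print(reduce_loss(loss, 'sum'))    # tensor(8.) - total over all 4 elements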
2. weight_reduce_loss(loss, weight=None, reduction='mean') - weighted reduction function
This is the core function: it applies an element-wise weight and then reduces the loss.
- Inputs:
  - loss (Tensor): the element-wise loss.
  - weight (Tensor, optional): element-wise weights with the same number of dimensions as loss, and either a single channel or the same channel count (this is what the two asserts enforce). Default: None.
  - reduction (str): same options as PyTorch's built-in losses, i.e. 'none', 'mean' or 'sum'. Default: 'mean'.
- Behavior:
  - If weight is given, the loss is first multiplied element-wise by it.
  - If weight is None, or reduction='sum', the (weighted) loss is simply passed to reduce_loss; with reduction='none' and a weight, the weighted element-wise loss is returned as-is.
  - If reduction='mean' and a weight is given, the mean is computed over the weighted region: loss.sum() is divided by the effective weight sum, where a single-channel weight is scaled by loss.size(1) because it is broadcast across all channels. Masked-out elements therefore do not dilute the average (see the sketch after this list).
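Note that the asserts read weight.size(1), so loss and weight need an explicit channel dimension, typically (N, C, H, W) maps in image restoration (the 1-D tensors in the weighted_loss docstring example would not pass this check). A minimal sketch of the "mean over the weighted region" behavior, assuming the functions above are in scope and an NCHW layout (all values illustrative):

import torch

# a (1, 1, 2, 2) element-wise loss map and a single-channel 0/1 mask
loss = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
weight = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])

print(weight_reduce_loss(loss, weight, reduction='mean'))
# the masked loss sums to 1 + 3 = 4; the divisor is
# weight.sum() * loss.size(1) = 2 * 1 = 2, so the result is tensor(2.)

print(weight_reduce_loss(loss, weight, reduction='sum'))   # tensor(4.)
print(weight_reduce_loss(loss, reduction='mean'))          # tensor(2.5000), plain mean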
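3. weighted_loss(loss_func) - the decorator (core interface)
The decorator ties the two helpers together: per its docstring, the wrapped function only needs to compute an element-wise loss, and the returned wrapper adds the weight and reduction arguments, delegating to weight_reduce_loss. Any extra keyword arguments are passed through to the original loss function. A usage sketch under the same assumptions as above (charbonnier_loss and the shapes are illustrative, not part of the original code):

import torch

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    # element-wise loss only; weighting and reduction come from the decorator
    return torch.sqrt((pred - target)**2 + eps)

pred = torch.rand(2, 3, 8, 8)
target = torch.rand(2, 3, 8, 8)
mask = (target[:, :1] > 0.5).float()   # single-channel mask, shape (2, 1, 8, 8)

print(charbonnier_loss(pred, target))        # scalar, plain mean
print(charbonnier_loss(pred, target, mask))  # mean over the masked region only
print(charbonnier_loss(pred, target, eps=1e-6, reduction='sum'))  # kwargs pass through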