文章目录
Pytorch框架学习 -3 pytorch损失函数
前置
无参数损失函数
定义
class _Loss(Module):
reduction: str
pytorch设计是真的强,我原本以为损失函数是单独定义的类,没想到他居然也是nn.Module的子类。
初始化方法
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
super(_Loss, self).__init__()
if size_average is not None or reduce is not None:
self.reduction = _Reduction.legacy_get_string(size_average, reduce)
else:
self.reduction = reduction
-
实现对基类的继承
-
如果size_average或者reduce非空调用
def legacy_get_string(size_average, reduce, emit_warning=True): # type: (Optional[bool], Optional[bool], bool) -> str warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: size_average = True if reduce is None: reduce = True if size_average and reduce: ret = 'mean' elif reduce: ret = 'sum' else: ret = 'none' if emit_warning: warnings.warn(warning.format(ret)) return ret来推断reduction的应该是的情况
-
否则直接使用传入的reduce
这估计是最友好的pytorch基类了吧!
一个实现的例子L1范数损失
class L1Loss(_Loss):
__constants__ = ['reduction']
def __init__(self, size_average=None,

本文主要介绍了Pytorch框架中的损失函数,包括无参数损失函数和有参数损失函数。无参数损失函数如L1范数损失,其继承自nn.Module,并实现了forward方法。其他损失函数大多遵循相同模板实现。有参数损失函数则在nn.Module的基础上注册参数。
最低0.47元/天 解锁文章
4525

被折叠的 条评论
为什么被折叠?



