文章目录
Pytorch框架学习 -3 pytorch损失函数
前置
无参数损失函数
定义
class _Loss(Module):
reduction: str
pytorch设计是真的强,我原本以为损失函数是单独定义的类,没想到他居然也是nn.Module的子类。
初始化方法
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
super(_Loss, self).__init__()
if size_average is not None or reduce is not None:
self.reduction = _Reduction.legacy_get_string(size_average, reduce)
else:
self.reduction = reduction
-
实现对基类的继承
-
如果size_average或者reduce非空调用
def legacy_get_string(size_average, reduce, emit_warning=True): # type: (Optional[bool], Optional[bool], bool) -> str warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: size_average = True if reduce is None: reduce = True if size_average and reduce: ret = 'mean' elif reduce: ret = 'sum' else: ret = 'none' if emit_warning: warnings.warn(warning.format(ret)) return ret
来推断reduction的应该是的情况
-
否则直接使用传入的reduce
这估计是最友好的pytorch基类了吧!
一个实现的例子L1范数损失
class L1Loss(_Loss):
__constants__ = ['reduction']
def __init__(self, size_average=None,