torch.nn.init(nn/init.py)
This file defines a number of parameter (weight) initialization methods.
Code:
Imported packages:
import math
import random
import torch
from torch.autograd import Variable
def calculate_gain(nonlinearity, param=None):
    # For the given nonlinearity function, return the recommended gain value,
    # used later by the Xavier initializations. The gains are:
    # ============ ==========================================
    # nonlinearity gain
    # ============ ==========================================
    # linear       1
    # conv{1,2,3}d 1
    # sigmoid      1
    # tanh         5/3
    # relu         sqrt(2)
    # leaky_relu   sqrt(2 / (1 + negative_slope^2))
    # ============ ==========================================
    # Args:
    #   nonlinearity: the nonlinear function (`nn.functional` name)
    #   param: optional parameter for the nonlinear function
    # Examples:
    #   >>> gain = nn.init.calculate_gain('leaky_relu')
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01  # default slope, matching nn.LeakyReLU
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, so booleans must be excluded explicitly
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))  # no gain is defined for this nonlinearity
def uniform(tensor, a=0, b=1):
    # Fill the Tensor or Variable with values drawn from the uniform distribution U(a, b)
    # Args:
    #   tensor: a torch.Tensor or autograd.Variable to fill
    #   a: lower bound of the uniform distribution
    #   b: upper bound of the uniform distribution
    # Examples:
    #   >>> w = torch.Tensor(3, 5)
    #   >>> nn.init.uniform(w)
    if isinstance(tensor, Variable):
        # For a Variable, fill its underlying data tensor, then return the Variable itself
        uniform(tensor.data, a=a, b=b)
        return tensor
    # For a plain Tensor, fill it in place with the in-place uniform_ method
    return tensor.uniform_(a, b)
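As a quick sanity check, a small sketch (the bounds -0.1 and 0.1 are illustrative) that fills a weight tensor in place and verifies every value falls inside [a, b]:

import torch
from torch.nn import init

w = torch.Tensor(3, 5)          # uninitialized weight tensor
init.uniform(w, a=-0.1, b=0.1)  # fill in place from U(-0.1, 0.1)
assert w.min() >= -0.1 and w.max() <= 0.1

Note that in later PyTorch releases this function was renamed to init.uniform_, following the in-place naming convention.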