torch.nn.init(nn/init.py)

torch.nn.init模块包含多种权重初始化方法,用于在定义网络后初始化权重,确保模型训练的效果。这些方法包括对2维张量赋予对角矩阵等,以保持输入特性。nn.Module子类虽部分初始化权重,但有时需自定义初始化,该模块提供此功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.nn.init(nn/init.py)

该文件定义了一些参数(权重)初始化方法
代码:
导入的包:

    import math
    import random
    import torch
    from torch.autograd import Variable
def calculate_gain(nonlinearity, param=None):
     #对于给定的非线性函数,返回推荐的增益值,用于后面的Xavier初始化
     #不同的增益值如下:
     # ============ ==========================================
          非线性函数                     增益值
     # ============ ==========================================
          linear :                     math:`1`
        conv{
  
  1,2,3}d :                 math:`1`
          sigmoid :                    math:`1`
           tanh :                    math:`5 / 3`
           relu :                  math:`\sqrt{
  
  2}`
        leaky_relu :  math:`\sqrt{
  
  2 / (1 + negative\_slope^2)}`
     # ============ ==========================================
     # 参数:
     #nonlinearity: the nonlinear function (`nn.functional`)
     #param:     optional parameter for the nonlinear function
     #Examples:
     # >>> gain = nn.init.calculate_gain('leaky_relu')
     linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
     if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
          return 1
     elif nonlinearity == 'tanh':
          return 5.0 / 3
     elif nonlinearity == 'relu':
          return math.sqrt(2.0)
     elif nonlinearity == 'leaky_relu':
          if param is None:
               negative_slope = 0.01                                                                                                     
               #默认值为0.01
          elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): 
               negative_slope = param
          else:
               raise ValueError("negative_slope {} not a valid number".format(param))
          return math.sqrt(2.0 / (1 + negative_slope ** 2))
     else:
          raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))          #返回非线性函数增益不支持
def uniform(tensor, a=0, b=1):
     #给Tensor或者Variable填充值使其满足均匀分布U(a,b)
     #参数:
     #tensor: 一个待填充的 torch.Tensor or autograd.Variable
     #a: 均匀分布下界
     #b: 均匀分布上届
     #Examples:
     # >>> w = torch.Tensor(3, 5)
     # >>> nn.init.uniform(w)
     if isinstance(tensor, Variable):             
     #如果是Variable类型,给他的data tensor填充数值即可
          uniform(tensor.data, a=a, b=b)
    
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值