在深度学习中常见的初始化操作

目录

截断正态分布来初始化张量

逐行代码解释

相关理论解释

截断正态分布函数

截断正态分布的定义

截断正态分布的作用

计算截断点的作用

具体步骤

正态分布的累积分布函数(CDF)

 正态分布的累积分布函数与误差函数的关系

示例计算

误差函数

应用:

定义:

误差函数的性质

Python 中的误差函数

总结


截断正态分布来初始化张量

import math
import warnings
import torch

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

逐行代码解释

1、正态分布的累积分布函数(CDF)norm_cdf 函数计算标准正态分布的累积分布函数。

def norm_cdf(x):
    return (1. + math.erf(x / math.sqrt(2.))) / 2.

2、警告:检查均值是否在截断边界 [a, b] 的2个标准差范围内,如果不在,则发出警告。

if (mean < a - 2 * std) or (mean > b + 2 * std):
    warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                  "The distribution of values may be incorrect.",
                  stacklevel=2)

3、不跟踪梯度:以下代码块确保初始化时不跟踪梯度,这对于设置神经网络的初始权重很有用。

with torch.no_grad():
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    tensor.uniform_(2 * l - 1, 2 * u - 1)
    tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    tensor.clamp_(min=a, max=b)
    return tensor
  • lu 是截断点 ab 处的累积分布函数值。
  • tensor.uniform_(2 * l - 1, 2 * u - 1) 用从指定范围的均匀分布生成的值初始化张量。
  • tensor.erfinv_() 对张量应用误差函数的逆函数。
  • tensor.mul_(std * math.sqrt(2.)) 将张量的值缩放到期望的标准差。
  • tensor.add_(mean) 将张量的值平移到期望的均值。
  • tensor.clamp_(min=a, max=b)</
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值