目录
截断正态分布来初始化张量
import math
import warnings
import torch
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor
逐行代码解释
1、正态分布的累积分布函数(CDF):norm_cdf
函数计算标准正态分布的累积分布函数。
def norm_cdf(x):
return (1. + math.erf(x / math.sqrt(2.))) / 2.
2、警告:检查均值是否在截断边界 [a, b]
的2个标准差范围内,如果不在,则发出警告。
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
3、不跟踪梯度:以下代码块确保初始化时不跟踪梯度,这对于设置神经网络的初始权重很有用。
with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor
l
和u
是截断点a
和b
处的累积分布函数值。tensor.uniform_(2 * l - 1, 2 * u - 1)
用从指定范围的均匀分布生成的值初始化张量。tensor.erfinv_()
对张量应用误差函数的逆函数。tensor.mul_(std * math.sqrt(2.))
将张量的值缩放到期望的标准差。tensor.add_(mean)
将张量的值平移到期望的均值。tensor.clamp_(min=a, max=b)
</