Pytorch参数初始化设置

在PyTorch中,如果不对网络参数进行显式初始化,各层会使用其默认的初始化方法。不同层类型的初始化策略有所不同,以下是常见层的默认初始化方式:

1. 全连接层 (nn.Linear)

权重初始化:使用Kaiming均匀分布(He初始化),假设激活函数为Leaky ReLU(负斜率a=sqrt(5))。初始化范围根据输入维度(fan_in)计算,公式为:

bound = 1 fan_in \text{bound} = \frac{1}{\sqrt{\text{fan\_in}}} bound=fan_in 1

权重从均匀分布 ( U(-\text{bound}, \text{bound}) ) 中采样。
偏置初始化:均匀分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。

2. 卷积层 (nn.Conv2d, nn.Conv1d, nn.Conv3d)

权重初始化:与全连接层类似,使用Kaiming均匀分布,但fan_in计算为输入通道数乘以卷积核面积(例如,对于Conv2dfan_in = in_channels * kernel_height * kernel_width)。
偏置初始化:与全连接层相同,均匀分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。

3. LSTM/GRU层 (nn.LSTM, nn.GRU)

权重初始化:权重从均匀分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
偏置初始化:偏置分为两部分,一部分初始化为零,另一部分从均匀分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。

4. 批归一化层 (nn.BatchNorm1d, nn.BatchNorm2d)

缩放参数(weight):初始化为1。
偏移参数(bias):初始化为0。

5. 嵌入层 (nn.Embedding)

权重初始化:从正态分布 ( N(0, 1) ) 中采样。

默认初始化的潜在问题

激活函数不匹配:Kaiming初始化默认假设使用Leaky ReLU(a=sqrt(5)),若使用ReLU或其他激活函数,可能需要手动调整初始化方式以避免梯度不稳定。
深层网络训练:默认初始化在较浅网络中表现良好,但在深层网络中可能需要更精细的初始化(如Xavier或正交初始化)。

代码示例:查看默认初始化

import torch.nn as nn

# 定义层
linear = nn.Linear(100, 50)
conv = nn.Conv2d(3, 16, kernel_size=3)
lstm = nn.LSTM(input_size=10, hidden_size=20)

# 打印权重范围和标准差
def print_init_info(module):
    for name, param in module.named_parameters():
        if 'weight' in name:
            print(f"{name} mean: {param.data.mean():.4f}, std: {param.data.std():.4f}, range: [{param.data.min():.4f}, {param.data.max():.4f}]")
        elif 'bias' in name:
            print(f"{name} mean: {param.data.mean():.4f}")

print("Linear层初始化信息:")
print_init_info(linear)

print("\nConv2d层初始化信息:")
print_init_info(conv)

print("\nLSTM层初始化信息:")
print_init_info(lstm)

手动初始化推荐

若默认初始化不适用,可手动初始化以适配激活函数:

# 针对ReLU的Kaiming初始化
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LSTM):
        for name, param in module.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

司南锤

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值