torch module.apply(fn)说明

本文介绍了如何在PyTorch中使用nn.Module及其子类,并详细讲解了如何定义weight_init函数来初始化子模块的权重,以及如何通过apply()方法递归地应用该函数。通过实例展示,读者将学会如何正确地对神经网络模型进行权重初始化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

应用的对象应是nn.Module类型或者它的继承类。
可以有children。

比如

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
net = Model()

这里的net就是一个Module对象,它包含了children: conv1, conv2

fn用来给childer初始化weight,
比如fn可以这样定义:

def weight_init(m):
    """ Weight initialization function. """
    # Conv2D
    if isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    else:
        pass

用的时候Applies fn recursively to every submodule (as returned by .children()) as well as self.
可看到官方源码如下,先对children应用fn, 再对self应用,最后返回self

def apply(self: T, fn: Callable[['Module'], None]) -> T:
     for module in self.children():
         module.apply(fn)
     fn(self)
     return self

应用例:

net.apply(weight_init)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝羽飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值