torch.nn.Module中__call__和forward方法

torch.nn.Module 重载了 Python 的 call 方法。
当你执行 net(data) 时,call 方法会被触发。
call 方法的主要作用是:
处理一些额外的逻辑(如 hooks 的执行)。
最终调用 forward 方法来执行模型的前向传播。

### torch.nn.utils._stateless 模块功能与使用 #### 功能概述 `torch.nn.utils._stateless` 是 PyTorch 中的一个内部模块,主要用于无状态地调用模型。该模块允许开发者在不改变原始模型参数的情况下执行前向传播其他操作。这有助于测试不同参数配置的效果而不影响实际模型的状态。 此模块特别适用于需要临时替换某些权重或缓冲区的情况,比如评估不同的初始化方案或是实现自定义优化算法时保持原有模型不变[^2]。 #### 主要方法及其作用 - `functional_call`: 这是最主要的方法之一,接受一个模型实例以及两组字典形式的输入——一组用于存储可训练参数(`parameters`);另一组则用来保存不可训练但又必需参与计算过程中的缓存数据(如BatchNorm里的running_meanrunning_var)(`buffers`)。通过这种方式可以在不影响原对象的前提下完成一次完整的预测流程。 ```python import torch from torch import nn from torch.nn.utils import stateless as sls class SimpleModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 5) def forward(self, x): return self.linear(x) # 创建模型并获取其初始参数buffer副本 model = SimpleModel() params = {name: param.clone().detach() for name, param in model.named_parameters()} buffers = {name: buffer.clone().detach() for name, buffer in model.named_buffers()} # 准备一些虚拟的数据作为输入 input_data = torch.randn((8, 10)) # 使用_stateless.functional_call来运行这个简单的线性变换 output = sls.functional_call(model, (params, buffers), input_data) print(output.shape) # 应输出 torch.Size([8, 5]) ``` 上述代码展示了如何利用 `_stateless.functional_call` 方法来进行一次独立于现有模型状态之外的操作。这里创建了一个小型神经网络,并复制了它的所有参数缓冲区到新的张量中去。接着就能够在这些新构建出来的张量上进行任意修改而不用担心会影响到原来的模型结构。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值