『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

本文详细解析了PyTorch中nn.Module类的工作原理,包括__setattr__()方法如何处理参数和子模块的存储,以及__getattr__()方法的查询机制。通过示例展示了不同类型的属性如何被管理和访问。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、背景知识

python中两个属相相关方法

result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__getattr__('name')方法,再无则报错

obj.name = value 会调用builtin函数setattr(obj,'name',value)设置对应属性,如果设置了__setattr__('name',value)方法则优先调用此方法,而非直接将值存入__dict__并新建属性

二、nn.Module的__setattr__()方法逻辑

nn.Module中实现了__setattr__()方法,当再class的初始化__init__()中执行module.name=value时,会在其中判断value是否属于Parameters或者nn.Module对象,是则将之存储进入__dict__._parameters和__dict__._modules两个字典中;如果是其他对象诸如Variable、List、dict等等,则调用默认操作,将值直接存入__dict__中。

示例

nn.Module的新建Parameter属性,在._parameters中可以查询到,在.__dict__中没有,属于.__dict__._parameters中

import torch as t
import torch.nn as nn

module = nn.Module()
module.param = nn.Parameter(t.ones(2,2))

print(module._parameters)

"""
OrderedDict([('param', Parameter containing:
               1  1
               1  1
              [torch.FloatTensor of size 2x2])])
"""

print(module.__dict__)
"""
{'_backend': <torch.nn.backends.thnn.THNNFunctionBackend at 0x7f5dbcf8c160>,
 '_backward_hooks': OrderedDict(),
 '_buffers': OrderedDict(),
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_modules': OrderedDict(),
 '_parameters': OrderedDict([('param', Parameter containing:
                                          1  1
                                          1  1
                                         [torch.FloatTensor of size 2x2])]),
 'training': True}
"""

 

以通常List的格式传入的子Module直接从属于属于.__dict__,并未被_modules识别

submodule1 = nn.Linear(2,2)
submodule2 = nn.Linear(2,2)
module_list = [submodule1,submodule2]
module.submodules = module_list

print('_modules:',module_list)
# _modules: [Linear (2 -> 2), Linear (2 -> 2)]
print('__dict__[submodules]:',module.__dict__.get('submodules'))
# __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]
print('__dict__[submodules]:',module.__dict__['submodules'])
# __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]

 

以ModuleList格式传入的子Module可被._modules识别,而不直接从属于.__dict__

module_list = nn.ModuleList(module_list)
module.submodules = module_list

print(isinstance(module_list,nn.Module))
# True

print(module._modules)
"""
OrderedDict([('submodules', ModuleList (
  (0): Linear (2 -> 2)
  (1): Linear (2 -> 2)
))])
"""
print(module.__dict__.get('submodules'))
# None
print(module.__dict__['submodules'])
"""
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-19-d4344afabcbf> in <module>()
----> 1 print(module.__dict__['submodules'])

KeyError: 'submodules'
"""

三、属性查询函数__getattr__相关特性

nn.Module的.__getattr__()方法会对__dict__._module、__dict__._parameters和__dict__._buffers这三个字典中的key进行查询。当nn.Module进行属性查询时,会先在__dict__进行查询(仅查询本级),查询不到对应属性值时,就会调用.__getattr__()方法,再无结果就报错。

示例

对于__dict__中的属性.training,可以看到.__getattr__('training')查询时就没有结果,

print(module.__dict__.get('submodules'))
# None

getattr(module,'training')
# True

module.training
# True


module.__getattr__('training')
"""
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
……
AttributeError: 'Module' object has no attribute 'training'
"""

 

 另外,我们可以看到.__getattr__可以查询到的结果如下,都是nn.Module自建的属性,

module.__getattr__
"""
<bound method Module.__getattr__ of Module (
  (submodules): ModuleList (
    (0): Linear (2 -> 2)
    (1): Linear (2 -> 2)
  )
)>
"""

 

对于普通的新建属性,其实和nn.Module自建的没什么不同,不同查询方式输出相似,

module.attr1 = 2
getattr(module,'attr1')
# 2

module.__getattr__('attr1')
"""
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
……
AttributeError: 'Module' object has no attribute 'attr1'
"""

 

对于nn.Module的特殊属性,可以看到,getattr和.__getattr__均可查到,这也是由于getattr一次查找无果后,调用.__getattr__的结果,

getattr(module,'param')
"""
Parameter containing:
 1  1
 1  1
[torch.FloatTensor of size 2x2]
"""

module.__getattr__('param')
"""
Parameter containing:
 1  1
 1  1
[torch.FloatTensor of size 2x2]
"""

 

### PyTorch 中 `torch.nn.utils.weight_norm` 的弃用情况 在 PyTorch 的开发历程中,某些功能可能会被标记为过时并最终移除。关于 `torch.nn.utils.weight_norm` 的具体弃用版本信息,在官方文档和社区讨论中并未明确指出其完全废弃的时间点[^1]。然而,可以确认的是,自 PyTorch 1.7 版本起,`weight_norm` 功能已被推荐使用更灵活的替代方案&mdash;&mdash;通过 `nn.Parameter` 或者其他模块化设计来实现权重标准化的效果。 #### 替代方案 对于需要实现类似 `weight_norm` 功能的需求,开发者通常会采用以下几种方式: 1. **手动实现权重规范化** 可以通过定义新的参数并将原始权重分解为其大小和方向部分来模拟 `weight_norm` 的行为。例如: ```python import torch import torch.nn as nn class WeightNormLinear(nn.Module): def __init__(self, in_features, out_features): super(WeightNormLinear, self).__init__() self.weight_v = nn.Parameter(torch.randn(out_features, in_features)) self.weight_g = nn.Parameter(torch.ones(out_features)) self.bias = nn.Parameter(torch.zeros(out_features)) def forward(self, x): weight_normalized = self.weight_g.unsqueeze(-1) * (self.weight_v / torch.norm(self.weight_v, dim=-1, keepdim=True)) return torch.matmul(x, weight_normalized.t()) + self.bias ``` 2. **利用 Spectral Normalization** 如果目标是为了控制模型的 Lipschitz 常数,可以考虑使用谱归一化(Spectral Normalization),它可以通过 `torch.nn.utils.spectral_norm` 实现类似的约束效果[^4]。 3. **基于最新 API 进行重构** 自 PyTorch 高级优化器引入以来,许多旧的功能逐渐被淘汰。因此建议重新评估需求,并尝试借助更高层次的设计模式完成相同的目标。 需要注意的是,尽管目前尚未有确切记录表明某个特定版本正式删除了该函数,但在实际项目维护过程中应尽量避免依赖已标注为 deprecated 的特性[^5]。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值