【pytorch 记录】pytorch的变量parameter_buffer、self.register_buffer()、self.register_parameter()

本文详细介绍了PyTorch中模型参数的管理方法,包括parameter和buffer的区别,如何使用register_buffer和register_parameter注册参数,以及如何通过state_dict保存和加载模型。
部署运行你感兴趣的模型镜像

在pytorch中模型需要保存下来的参数包括:

  • parameter:反向传播需要被 optimizer 更新的,可以被训练。
  • buffer:反向传播不需要被 optimizer 更新,不可被训练。

这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_dict() 返回进行保存。

1 nn.Module的介绍

需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们想要将一些变量保存(如yolov5中的anchor),可以用作简单的后处理,就需要将这种变量注册到网络中,可以使用的api为:self.register_buffer() :不可被训练;self.register_parameter()nn.parameter.Parameter()nn.Parameter():可以被训练。

对于pytorch定义网络时,都要继承与 nn.Module。到源码中看到该类的初始化中,成员变量如下,这里我们关心是绿色选中区域,这三个成员都是 OrderedDict() 类型的
在这里插入图片描述

成员变量:

  • _buffers:由self.register_buffer() 定义,requires_grad默认为False,不可被训练。
  • _parasmeter:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定义的变量都存放在该属性下,且定义的参数的 requires_grad 默认为 True。
  • _module:nn.Sequential()、nn.conv() 等定义的网络结构中的结构存放在该属性下。

成员函数:

  • self.state_dict():OrderedDict 类型。保存神经网络的推理参数,包括parameter、buffer
  • self.name_parameters():为迭代器。self._moduleself._parameters中所有的可训练参数的名字+tensor。包括 BN的 bn.weight、bn.bias。
  • self.parameters():与self.name_parameters()一样,但不包含名字
  • self.name_buffers():为迭代器。网络中所有的不可训练参数和自己注册的buffer 中的参数的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。
  • self.buffers():与self.name_buffers()一样,但不包含名字
  • net.named_modules():为迭代器。self._module中定义的网络结构的名字+层
  • net.modules()

2 代码示例

import torch  
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        """=======case1: self._modules======="""
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.TEST_1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        """=======case2: self._buffers======="""
        self.register_buffer('TEST_2', torch.randn(1, 3))

        """=======case3: model._parameters======="""
        self.register_parameter('TEST_30', nn.Parameter(torch.randn(1, 4)))
        self.TEST_31 = nn.parameter.Parameter(torch.tensor(1.0))
        self.TEST_32 = nn.Parameter(torch.tensor(2.0))

        """=======case4======="""
        self.TEST_4 = torch.randn(1, 2)

    def forward(self, x):
       return x

model = Model() 
print()
print(f'=========================================model._modules:\n{model._modules}\n') 
print(f'=========================================model._buffers:\n{model._buffers}\n') 
print(f'=========================================model._parameters:\n{model._parameters}\n')
print(f'=========================================model.state_dict():\n{model.state_dict()}\n')

其实debug方式查看会更便捷。直接打印也没有问题。
在这里插入图片描述

如果要打印介绍的成员函数的内容,则有:

named_buffers = [param for param in model.named_buffers()]
print(f'===================================named_buffers:\n{named_buffers}\n')

named_parameters = [param for param in model.named_parameters()]
print(f'===================================named_parameters:\n{named_parameters}\n')

named_modules = [param for param in model.named_modules()]
print(f'===================================named_modules:\n{named_modules}\n')

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think> 我们正在讨论PyTorch中的`self.register_buffer("pe", pe)`这行代码。 根据PyTorch的文档,`register_buffer`是`torch.nn.Module`的一个方法,用于注册一个缓冲区(buffer)。 缓冲区是模型的一部分,会被包含在`state_dict`中(因此会被保存和加载),但不会被视作模型的参数(即不会在训练中被优化器更新)。 具体分析: 1. 含义:将张量`pe`注册为模块的缓冲区,命名为"pe"。 2. 作用: - 自动将缓冲区转移到模型所在的设备(当调用`model.to(device)`时)。 - 在保存模型时,缓冲区会被包含在状态字典中(state_dict),从而被保存。 - 在加载模型时,缓冲区会从状态字典中恢复。 - 缓冲区不会参与梯度计算(即不会自动求导),因此适合存储不需要训练的参数(如预定义的常量、位置编码等)。 3. 使用场景: - 在Transformer中,位置编码(Positional Encoding)通常是固定的(非学习的),因此使用缓冲区存储。 - 其他需要存储但不需要训练的张量,例如:预计算的常数、掩码(mask)等。 示例:在Transformer的位置编码中,我们可能会这样使用: ```python import torch from torch import nn class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) # 创建一个位置编码矩阵 position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) # 调整形状 self.register_buffer('pe', pe) # 注册为缓冲区 def forward(self, x): x = x + self.pe[:x.size(0), :] # 添加位置编码 return x ``` 在这个例子中,位置编码`pe`是一个固定的张量,不需要训练,所以注册为缓冲区。当模型被移动到GPU时,这个`pe`也会自动移动到GPU;保存模型时,它会被保存;加载模型时,它会被加载。 与`nn.Parameter`的区别: - `nn.Parameter`是会被优化器训练的参数,而缓冲区不会。 - 使用`self.pe = nn.Parameter(pe, requires_grad=False)`也可以实现设备转移和保存加载,但是会作为参数(虽然不更新)出现在模型的`parameters()`迭代器中。而缓冲区不会出现在`parameters()`中,但会出现在`named_buffers()`中。 因此,使用`register_buffer`更合适,因为它明确表示了这不是可训练参数。 总结: `self.register_buffer("pe", pe)`用于将张量`pe`注册为模型的缓冲区,使其能够随模型一起保存、加载和转移设备,但不会参与训练。 相关问题: 1.PyTorch中,缓冲区和参数有什么区别? 2. 如何访问模型中的缓冲区? 3. 在模型的保存和加载过程中,缓冲区是如何处理的? 4. 除了位置编码,还有哪些典型场景会用到缓冲区? </think> 在PyTorch框架中,`self.register_buffer("pe", pe)` 是神经网络模块设计中的关键技术,用于管理不需要训练但需持久化保存的张量数据。以下是详细解析: ### 一、核心含义 1. **注册缓冲区**:将张量 `pe` 注册为模块的**持久化缓冲区**(buffer) 2. **非参数属性**:该张量属于模型状态的一部分,但**不会参与梯度计算**和优化器更新 3. **设备同步**:自动与模型其他部分保持相同设备(CPU/GPU) ### 二、核心作用 1. **设备自动迁移** 当执行 `model.to(device)` 时,注册的缓冲区自动迁移到指定设备: ```python model = PositionalEncoding(d_model=512) model.cuda() # 自动将pe转移到GPU ``` 2. **状态持久化** 缓冲区会被包含在模型状态字典中,随模型一起保存/加载: ```python torch.save(model.state_dict(), "model.pt") # 自动保存pe new_model.load_state_dict(torch.load("model.pt")) # 自动恢复pe ``` 3. **优化器隔离** 优化器不会处理缓冲区张量,避免不必要的计算: ```python for param in model.parameters(): print(param.requires_grad) # pe不会出现在这里 ``` ### 三、典型使用场景 1. **Transformer位置编码**(最常见场景) 固定位置编码矩阵不需要训练: ```python class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) # 创建位置编码矩阵 position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) # 注册为缓冲区 def forward(self, x): x = x + self.pe[:x.size(0)] # 添加位置信息 return x ``` 2. **预计算常量** 高频使用的固定计算量(如傅里叶变换基): ```python self.register_buffer('dct_matrix', compute_dct_matrix()) ``` 3. **运行时状态跟踪** 记录非训练相关状态(如EMA统计量): ```python self.register_buffer('moving_mean', torch.zeros(feature_dim)) ``` ### 四、与普通张量对比 | 特性 | `register_buffer` | 普通实例变量 | `nn.Parameter` | |---------------------|-----------------------|-------------------|--------------------| | 设备自动迁移 | ✅ | ❌ | ✅ | | 保存至state_dict | ✅ | ❌ | ✅ | | 优化器是否更新 | ❌ | ❌ | ✅ | | 是否计算梯度 | ❌ | ❌ | ✅ | | 序列化支持 | ✅ | ❌ | ✅ | ### 五、使用注意事项 1. **命名规范** 避免使用`_buffers`等PyTorch保留字作为名称 ```python self.register_buffer("_pe", pe) # 错误!双下划线为保留字段 ``` 2. **初始化时机** 必须在`__init__`中注册,不可在`forward`中动态注册: ```python class MyModule(nn.Module): def __init__(self): super().__init__() # 正确位置 self.register_buffer("mask", torch.eye(10)) def forward(self, x): # 错误!不能在forward中注册 self.register_buffer("temp", x) ``` 3. **访问方式** 通过模块属性直接访问: ```python print(model.pe) # 直接访问缓冲区 ``` > 该技术由PyTorch的`nn.Module`基类提供支持[^1],在Transformer等需要固定位置编码的模型中应用广泛[^3]。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值