【PyTorch】state_dict详解

本文介绍了PyTorch中state_dict的相关知识。torch.nn.Module模块的state_dict存放训练需学习的权重和偏执系数,映射成tensor张量,还可能包含batchnorm相关参数;torch.optim模块的Optimizer对象的state_dict包含特定字典对象。因其是Python字典,便于操作,最后给出简单案例输出其变量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文章已经生成可运行项目,

Introduce

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.youkuaiyun.com/wzy_zju/article/details/81262453

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

Sample

通过一个简单的案例来输出state_dict字典对象中存放的变量

#encoding:utf-8

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F

#define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.pool=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)

    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

def main():
    # Initialize model
    model = TheModelClass()

    #Initialize optimizer
    optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

    #print model's state_dict
    print('Model.state_dict:')
    for param_tensor in model.state_dict():
        #打印 key value字典
        print(param_tensor,'\t',model.state_dict()[param_tensor].size())

    #print optimizer's state_dict
    print('Optimizer,s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name,'\t',optimizer.state_dict()[var_name])



if __name__=='__main__':
    main()

 具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

Model.state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer,s state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

 

本文章已经生成可运行项目
<think>我们正在处理的是PyTorch模型参数的加载问题。用户提供的代码片段是:{k:vfork,vinsave_model.items()ifkinmodel_dict.keys()}这个字典推导式的作用是:从保存的模型状态字典(save_model)中筛选出当前模型(model_dict)中存在的键值对。在PyTorch中,`model.state_dict()`返回一个有序字典(OrderedDict),包含模型的所有可学习参数(权重和偏置)以及注册的缓冲区(如BatchNorm的running_mean)。使用`load_state_dict`加载参数时,要求保存的状态字典的键必须与当前模型的键完全匹配(除非设置了`strict=False`)。###代码功能详解1.**过滤参数**:该推导式遍历保存的模型状态字典(`save_model.items()`),只保留那些键存在于当前模型状态字典(`model_dict.keys()`)中的项。2.**应用场景**:当保存的模型和当前模型结构不完全一致时(例如,加载预训练模型但当前模型缺少某些层),通过这种方式可以避免因键不匹配而导致的错误。###在PyTorch中的使用示例```python#假设我们有一个保存的模型状态字典(来自文件或网络)saved_state_dict=torch.load('model.pth')#类型为OrderedDict或普通dict#获取当前模型的状态字典(作为键的参照)model=MyModel()#假设已经定义了模型类current_state_dict=model.state_dict()#当前模型的键集合#使用字典推导式过滤出当前模型存在的参数filtered_state_dict={k:vfork,vinsaved_state_dict.items()ifkincurrent_state_dict}#加载过滤后的参数model.load_state_dict(filtered_state_dict,strict=False)```###注意事项1.**键匹配**:必须确保保存的模型和当前模型的对应层有相同的键名。如果层名不一致,需要手动重命名或映射。2.**严格模式**:使用`strict=False`允许部分加载,即忽略不匹配的键(多余或缺失)。但需注意:-未加载的参数将保持随机初始化状态-缺失的键不会触发错误,但可能影响模型性能3.**数据类型一致性**:保存的参数和当前模型参数的数据类型必须兼容。例如,不能将`FloatTensor`加载到`IntTensor`层(如引用[4]中的问题)[^4]。4.**缓冲区(Buffer)**:除了可学习参数,状态字典还可能包含缓冲区(如BatchNorm的统计量)。这些也需要匹配。###常见错误及解决方案1.**键不匹配错误(KeyError)**:-原因:保存的模型有当前模型不存在的键。-解决:使用上述推导式过滤,或设置`strict=False`。2.**数据类型不匹配**:-现象:如引用[4]中,注册的缓冲区(buffer)在保存时被替换为浮点张量,但新模型期望整数[^4]。-解决:确保参数数据类型一致,避免直接替换模型属性(应使用`register_buffer`正确注册)。3.**形状不匹配**:-原因:保存的参数张量与当前模型对应层的形状不同。-解决:手动调整参数形状(如通过`v.view_as(current_state_dict[k])`)或跳过该参数。###进阶技巧-**参数重映射**:当键名不同但参数可对应时,构建映射字典:```pythonkey_mapping={"old_layer.weight":"new_layer.weight",...}remapped_state_dict={key_mapping.get(k,k):vfork,vinsaved_state_dict.items()}```-**部分初始化**:仅加载部分层的参数(如卷积层),其余层保持初始化:```python#只加载以"conv"开头的层filtered={k:vfork,vinsaved_state_dict.items()ifk.startswith("conv")andkincurrent_state_dict}```###总结该字典推导式是PyTorch模型迁移学习中常用的参数过滤手段,核心在于通过键名交集实现安全加载。配合`load_state_dict`的`strict=False`选项,可有效处理模型结构差异问题[^1][^2]。§§相关问题§§1.如何解决PyTorch模型加载时的键名不匹配问题?2.在模型微调(fine-tuning)中,如何处理部分层参数加载?3.PyTorch的`strict=False`模式有哪些潜在风险?4.如何正确注册和加载缓冲区(Buffer)参数?[^4]5.在分布式训练中加载模型参数有哪些特殊注意事项?</think>### 理解代码功能与 PyTorch 参数加载 您提供的代码片段: ```python {k: v for k, v in save_model.items() if k in model_dict.keys()} ``` 这是一个**字典推导式**,用于从保存的模型状态字典中筛选出当前模型存在的键值对。功能相当于: ```python filtered_dict = {} for k, v in save_model.items(): if k in model_dict.keys(): filtered_dict[k] = v ``` #### 在 PyTorch 参数加载中的核心作用: 1. **键匹配**:确保只加载当前模型结构中存在的参数 2. **安全过滤**:避免加载多余参数导致`KeyError` 3. **模型兼容**:处理不同版本模型间的参数差异 ### 正确使用 `load_state_dict` 的完整流程 ```python # 1. 加载保存的模型参数 save_model = torch.load('model.pth') # 类型为OrderedDictdict [^3] # 2. 获取当前模型的状态字典 model_dict = model.state_dict() # 3. 过滤匹配的参数 (您的代码片段) matched_params = {k: v for k, v in save_model.items() if k in model_dict.keys()} # 4. 更新当前模型参数 model_dict.update(matched_params) # 更新字典值 # 5. 加载到模型 model.load_state_dict(model_dict) # 严格模式加载 [^1] ``` ### 关键注意事项与常见错误解决 #### 1. 键不匹配问题 - **错误现象**:`RuntimeError: Missing key(s) in state_dict...` - **解决方案**: ```python # 使用strict=False允许部分加载 model.load_state_dict(matched_params, strict=False) ``` #### 2. 数据类型不一致 - **错误现象**:如引用[4]中的整数张量 vs 浮点张量问题 [^4] - **解决方案**: ```python # 显式转换数据类型 matched_params = {k: v.to(model_dict[k].dtype) for k, v in save_model.items() if k in model_dict.keys()} ``` #### 3. 模型结构变化处理 当添加/删除网络层时: ```python # 仅加载匹配参数,新层保持初始化 model.load_state_dict(matched_params, strict=False) # 获取未加载的键 missing, unexpected = model.load_state_dict(save_model, strict=False) print(f"未加载层: {missing}\n多余参数: {unexpected}") ``` #### 4. 设备位置不匹配 ```python # GPU模型加载CPU保存的参数 if next(iter(save_model.values())).device.type == 'cpu': matched_params = {k: v.to(device) for k, v in save_model.items() if k in model_dict.keys()} ``` ### 最佳实践建议 1. **版本兼容**:保存模型时同时存储结构元数据 ```python torch.save({ 'state_dict': model.state_dict(), 'model_version': '1.2', 'input_size': (3, 224, 224) }, 'model.pth') ``` 2. **参数验证**:加载后检查参数更新 ```python # 检查参数是否更新 for name, param in model.named_parameters(): if name in matched_params: print(f"{name}: {torch.allclose(param, matched_params[name])}") ``` 3. **缓冲区处理**:注册的缓冲区(如BN的running_mean)也需加载 [^4] ```python # 包含缓冲区过滤 all_keys = set(model.state_dict().keys()) filtered = {k: v for k, v in save_model.items() if k in all_keys} ``` ### 典型应用场景 ```mermaid graph TD A[加载预训练模型] --> B{是否完全匹配?} B -- 是 --> C[直接加载] B -- 否 --> D[使用键过滤] D --> E[部分参数初始化] E --> F[微调模型] ```
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值