PyTorch 多 GPU 训练保存模型权重 key 多了 ‘module.‘

在多GPU环境下使用PyTorch训练模型并保存权重时,加载state_dict会遇到'模块名.module.'的问题,导致加载错误。本文分析了该问题的原因,并提供了两种解决方案:(1)加载权重后修改key;(2)保存权重前增加'module.'。根据情况选择合适的方法可解决模型加载问题。
部署运行你感兴趣的模型镜像

一、问题表现

使用多 GPU 训练保存模型权重后,再次加载 state_dict 会出现 ‘“Missing key(s)” 错误,信息如下,可以发现预期的权重 key 比文件中保存的 key 少了 'module.' 。或者说,在多 GPU 训练的情况下,通过 torch.save() 保存的模型权重的 key 多了 'module.'。

RuntimeError: Error(s) in loading state_dict for LeNet:
	Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "classifier.0.weight", "classifier.0.bias", "classifier.2.weight", "classifier.2.bias". 
	Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.classifier.0.weight", "module.classifier.0.bias", "module.classifier.2.weight", "module.classifier.2.bias". 
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 24763 closing signal SIGTERM

具体原因未知,官方文档也未给出更详细的例子,只有一般用法,如下。

import torch
import torchvision.models as models

# PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

可能是因为执行多 GPU 训练时,使用官方推荐的 python -m torch.distributed.lauchtorchrun 工具所致。

二、解决方法

针对以上问题有两种解决方法,可以分为加载权重后保存权重前

(1)加载权重后修改 key

首先通过 torch.load() 加载权重文件,然后遍历字典,如果 key 中包含 'module' 则将其删掉,参考这里

weights_name = 'weights-ep10-1641471178.0502117.rank-0.pth'  
weights = torch.load(weights_name)

weights_dict = {}
for k, v in weights.items():
    new_k = k.replace('module.', '') if 'module' in k else k
    weights_dict[new_k] = v

model.load_state_dict(weights_dict)

(2)保存权重前增加 module

使用 torch.save() 保存权重时,通过 model.module.state_dict() 获取模型权重,而不是像官方示例中只用 model.state_dict() ,参考这里

model_weights_name = "weights.pth"
torch.save(model.module.state_dict(), model_weights_name)

注意在多 GPU 训练情况下才会出现保存模型权重的 key 多了 'module.',以上两种方法选择其中一种即可,例如,当已经拿到多 GPU 训练的模型时,使用方法(1)比较好;如果重新训练模型,则可以直接使用方法(2)。

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>我们正在解决一个具体问题:使用FSDP训练模型保存权重后重新加载却无法转换为ONNX格式。根据引用[2],FSDP(Fully Sharded Data Parallel)是一种分布式训练技术,它将模型参数、梯度和优化器状态分片到GPU上,以支持大模型训练。 问题分析: 1. FSDP训练后的模型保存时,由于模型被分片,每个进程可能只保存模型的一部分(shard)。 2. 当我们尝试加载权重时,如果没有正确的聚合所有分片,模型将是不完整的。 3. 转换ONNX需要完整的模型结构,因此直接加载单个分片会导致失败。 解决方案: 我们需要在保存模型之前,将模型聚合到单个进程(rank0),然后由rank0保存完整的模型。这样重新加载时就是一个完整的模型,可以顺利转换ONNX。 步骤: 1.训练结束后,使用FSDP提供的`consolidate_sharded_weights`(或类似功能)将分片的权重合并。 2. 仅由rank0进程保存完整模型。 3. 加载时,使用单个进程加载完整模型,然后进行ONNX转换。 具体代码实现(基于PyTorch 1.12+): ```python from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType, FullStateDictConfig # 假设我们的模型model,已经用FSDP包装 # 训练完成后,保存模型 # 设置状态字典类型为全状态(完整模型) full_state_dict_config = FullStateDictConfig(rank0_only=True, offload_to_cpu=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): # 此时,只有rank0会得到完整的模型状态字典,其他rank得到空字典 state_dict = model.state_dict() if dist.get_rank() == 0: # 保存完整模型,注意:这里保存的state_dict已经是完整的 torch.save(state_dict, "full_model.pth") ``` 重新加载模型并转换为ONNX: 注意:加载模型和转换ONNX必须在非分布式环境下进行(或者单个进程)。 ```python # 在单个GPU或CPU上 model = MyLargeModel() # 创建与原始模型相同的非FSDP模型 state_dict = torch.load("full_model.pth") model.load_state_dict(state_dict) model.eval() # 转换为ONNX dummy_input = torch.randn(1, 3, 224, 224) # 根据模型输入调整 torch.onnx.export(model, dummy_input, "model.onnx") ``` 可能遇到的坑: 1. 如果原始模型使用了混合精度(如FP16),在保存全精度模型时,需要将模型转换为FP32再保存,或者确保ONNX转换时使用相同精度。 2. 如果模型结构在FSDP包装后有所变化(比如被包装在FSDP模块内),在加载时需要创建原始模型结构(没有FSDP包装的)。 补充说明: 如果使用PyTorch版本较低(低于1.12),可能需要使用其他方法,如`fsdp_model.module.state_dict()`(但注意这可能会在分片模型上失败)或者自定义收集函数。 另一种方法是使用`torch.save(fsdp_model.state_dict(), ...)`,但这样保存的是分片的状态字典,每个进程保存自己的分片。这样就需要在加载时重新用FSDP包装模型,然后加载状态字典,但是这样得到的模型仍然是分片的,不适合直接转换ONNX。 因此,推荐使用上述的`StateDictType.FULL_STATE_DICT`方法。 如果问题仍然存在,考虑以下可能: - 模型结构在保存和加载时不一致(如训练时有FSDP包装,而加载时没有,或者反之)。确保加载时创建的是非FSDP模型的实例。 - 检查模型结构定义是否一致(如使用了条件分支,导致结构变化)。 引用说明: 根据引用[2],FSDP支持模型分片,因此我们在保存时需要将分片合并。引用[3]提到训练是一个循环迭代过程,我们在训练完成后进行保存。引用[4]提到大模型训练中的梯度问题,但与本问题关联不大。 相关问题: 1. 如何在使用FSDP分布式训练保存检查点(checkpoint)并恢复训练? 2. 使用FSDP训练模型能否直接部署在单个GPU上? 3. 如何将FSDP训练模型转换为TensorRT格式? 注意:以上解决方案假设用户使用的是PyTorch 1.12及以上版本。对于更早版本,可能需要使用`model.module.state_dict()`(如果FSDP模型有`module`属性)或者自定义收集函数。 </think> ### 解决 FSDP 模型保存后无法转 ONNX 的问题 在使用 PyTorch `torch.distributed.fsdp` 训练模型转换 ONNX 时,主要问题是 FSDP 的分片机制导致模型状态不完整。以下是系统化的解决方案: #### 一、核心问题分析 1. **分片权重问题** FSDP 将模型参数分片存储在不同设备上,直接保存的 checkpoint 不包含完整模型参数[^2] 2. **元数据缺失** ONNX 转换需要完整的计算图结构,但 FSDP 的分布式包装层破坏了原始模型结构 3. **设备不一致** 保存的 checkpoint 包含设备信息(如 `cuda:0`),导致加载时出现设备不匹配错误 --- #### 二、完整解决方案 ##### 步骤 1:保存完整模型状态 使用 FSDP 提供的官方方法合并分片权重: ```python from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType # 配置全量状态字典(完整模型) fsdp_state_dict_cfg = FSDP.FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fsdp_state_dict_cfg): full_state_dict = model.state_dict() if dist.get_rank() == 0: # 仅 rank0 保存 torch.save({ 'model': full_state_dict, 'model_config': model.config # 保存原始模型配置 }, "full_model.pth") ``` ##### 步骤 2:重建原始模型结构 加载时使用原始模型类(非 FSDP 包装版本): ```python # 重建原始模型(非FSDP包装) original_model = MyModelClass(**model_config) # 加载完整状态字典 checkpoint = torch.load("full_model.pth", map_location='cpu') original_model.load_state_dict(checkpoint['model']) # 关键:移除所有分布式相关属性 for module in original_model.modules(): if hasattr(module, '_fsdp_wrapped_module'): delattr(module, '_fsdp_wrapped_module') ``` ##### 步骤 3:ONNX 转换最佳实践 ```python original_model.eval().cpu() # 确保在CPU模式 # 动态轴配置(处理batch/seq长度变化) dynamic_axes = { 'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'} } # 导出ONNX torch.onnx.export( original_model, dummy_input.cpu(), # 确保输入在CPU "model.onnx", opset_version=13, # 推荐使用最新opset do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes ) ``` --- #### 三、常见错误处理 | 错误类型 | 解决方案 | |---------|----------| | `KeyError: 'module._fsdp_...'` | 使用 `FullStateDictConfig` 保存完整状态字典 | | `RuntimeError: tensor on meta device` | 加载前执行 `model.to('cpu')` | | `Unsupported ONNX opset: FSDPWrapper` | 转换前移除所有分布式包装层 | | `Input type mismatch` | 确保 dummy_input 与模型训练时格式一致 | --- #### 四、高级调试技巧 1. **模型结构验证** 导出前检查模型结构完整性: ```python # 检查是否存在FSDP残留 for name, module in original_model.named_modules(): if 'fsdp' in name.lower(): raise RuntimeError(f"FSDP残留模块: {name}") # 验证参数连续性 for param in original_model.parameters(): if param.device.type == 'meta': raise RuntimeError("存在meta device参数") ``` 2. **ONNX 简化工具** 转换后使用 ONNX Simplifier 优化: ```bash python -m onnxsim model.onnx simplified.onnx ``` 3. **可视化验证** 使用 Netron 工具检查导出的 ONNX 结构 ---
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值