Pytorch DDP DistributedDataParallel load_state_dict checkpoint 加载断点 继续训练的正确打开方式

一、网上的一些博客是错误的,我按照他们的方法去改,结果不对,比如:

Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)
[原创][深度][PyTorch] DDP系列第二篇:实现原理与源代码解析

二、pytorch官网的说法也是让人摸不着头脑,官方说法,但是保存checkpoint 的做法是对的,直接在主进程共保存即可,这点不再赘述,但是加载断点文件到DDP中的说法不是很清楚。特别是map_location = {‘cuda:%d’ % 0: ‘cuda:%d’ % rank} 很容易让人误解。

三、正确的方法:

1、在你的Model()类初始化和DistributedDataParallel()类初始化中间加载断点:在每个进程运行到这里时都会从硬盘中读取文件到CPU,然后再到每一个显卡,两种方式都可。主要是设置map_location参数,否则会出现如官网描述的错误,主显卡显存略大,多出来几个进程等等。这都是不是正确的打开方式。。。

在这里插入图片描述

2、1为何要去掉module.?因为DistributedDataParallel之前的model的state_dict是单卡的,之后他就会加上module.,如果按下面的保存方式就需要在加载断点之前去掉。当然你也可以保存为单卡的state_dict:orch.save(model.module.state_dict(), pth) (可能不对,网上很多教程,自行查找)。另外如果你在DistributedDataParallel()之后,开始第一个训练的epoch时加载模型参数,那么也不需要去掉module.

在这里插入图片描述

3、加载和保存优化器?这里我感觉很复杂,我没有使用保存优化器,我只是使用保存的断点做预训练和测试,所以不需要优化器断点训练。有懂的大佬可以留言!
4、DDP很难用,各种bug和小技巧,主要是文档写的不好,ε=(´ο`*)))唉
5、请不要忘了看看国外网友的参考,写的都很好:

https://discuss.pytorch.org/t/what-is-the-proper-way-to-checkpoint-during-training-when-using-distributed-data-parallel-ddp-in-pytorch/139575

https://github.com/pytorch/pytorch/issues/23138

<think>好的,用户问的是如何在PyTorch加载模型,需要示例代码和最佳实践。首先,我需要回想PyTorch的模型保存和加载机制。记得有两种主要方式:保存整个模型和仅保存状态字典。保存整个模型可能更方便,但占用空间大,且不利于跨项目使用。状态字典更灵活,推荐的做法是保存状态字典,然后在加载时重新初始化模型结构,再加载参数。 接下来,用户可能需要示例代码。我应该分别展示两种保存方式对应的加载方法。比如,使用torch.save(model.state_dict(), 'model.pth')的话,加载时需要先创建模型实例,再用load_state_dict。而如果保存了整个模型,直接加载即可,但要注意模型类必须在当前作用域中定义过。 另外,用户可能关心设备的问题,比如加载到CPU或GPU。这时候需要解释map_location参数的使用,比如在加载时指定device,或者自动处理。比如从GPU保存的模型加载到CPU上,就需要调整。 还有可能涉及的是多GPU训练后的加载,这时候可能需要处理模型被包装在DataParallel或DistributedDataParallel中的情况。应该提到加载时要去掉module.前缀,或者使用strict=False参数来避免键不匹配的错误。 最佳实践部分,应该强调始终保存状态字典而非整个模型,确保模型结构在加载时可重建,处理设备映射,版本兼容性问题,以及加载后调用model.eval()进入推理模式。 可能还需要提到验证加载后的模型,比如检查参数是否正确,或者进行前向传播测试。另外,如果用户需要加载训练模型并修改部分层,应该如何处理,比如忽略某些键或加载部分权重。 最后,用户可能需要了解常见的错误和解决方法,比如键不匹配、设备不匹配、版本不兼容等。这时候给出示例代码和解释,帮助他们避免这些问题。 总结下来,回答的结构应该是先介绍两种加载方法,给出代码示例,然后讨论设备处理,多GPU情况,最佳实践,常见问题处理,以及验证步骤。确保用户能够一步步按照指导操作,并理解背后的原理。</think>以下是关于如何在PyTorch加载模型的详细说明,包含示例代码和最佳实践: --- ### 一、基础方法:加载模型 #### 1. 仅加载模型参数(推荐方式) ```python import torch import torch.nn as nn # 定义原始模型结构 class MyModel(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(10, 2) def forward(self, x): return self.fc(x) # 实例化模型并加载参数 model = MyModel() model.load_state_dict(torch.load('model_weights.pth')) model.eval() # 设置为评估模式 ``` #### 2. 加载完整模型(需保留类定义) ```python model = torch.load('full_model.pth') model.eval() ``` --- ### 二、设备处理技巧 #### 1. GPU模型加载到CPU ```python model.load_state_dict(torch.load('gpu_model.pth', map_location=torch.device('cpu'))) ``` #### 2. 自动选择可用设备 ```python device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.load_state_dict(torch.load('model.pth', map_location=device)) ``` --- ### 三、多GPU训练模型的加载 处理由`DataParallel`包装的模型: ```python # 原始保存方式(包含module.前缀) torch.save(model.module.state_dict(), 'ddp_model.pth') # 加载时移除前缀 state_dict = torch.load('ddp_model.pth') new_state_dict = {k.replace('module.', ''): v for k,v in state_dict.items()} model.load_state_dict(new_state_dict) ``` --- ### 四、最佳实践 1. **参数验证**:加载后检查参数匹配性 ```python missing, unexpected = model.load_state_dict(torch.load('model.pth'), strict=False) print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}") ``` 2. **版本控制**:同时保存模型元数据 ```python torch.save({ 'epoch': 200, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, 'checkpoint.pth') ``` 3. **部分加载**:选择性加载参数(适用于迁移学习) ```python pretrained_dict = torch.load('pretrained.pth') model_dict = model.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) model.load_state_dict(model_dict) ``` --- ### 五、常见问题解决方案 1. **设备不匹配错误**:使用`map_location`参数指定目标设备 2. **类定义丢失**:确保模型类定义在加载前已存在 3. **参数维度不匹配**:使用`strict=False`加载兼容参数 4. **CUDA内存不足**:先加载到CPU再转移至GPU ```python model.load_state_dict(torch.load('large_model.pth', map_location='cpu')) model.to('cuda') ``` --- ### 六、验证加载结果 ```python # 验证前向传播 test_input = torch.randn(1, 10) with torch.no_grad(): output = model(test_input) print(output.shape) # 应输出 torch.Size([1, 2]) # 检查第一层权重 print(model.fc.weight[0][:3]) # 应与保存前的数值一致 ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值