torch DDP训练-模型保存-加载问题

本文详细介绍了PyTorch中单GPU及DDP(Distributed Data Parallel)模式下模型的保存与加载方法,包括无module形式和module形式,并提供了在DDP模式下训练时遇到的问题及其解决办法。此外,还阐述了DDP训练的基本流程,包括初始化、使用DistributedSampler和分布式训练等关键步骤。

单GPU保存与加载:

模型保存:
####方法一
state = {'epoch': epoch,
         'model': model.state_dict(),
         'optimizer': optimizer.state_dict(),
         'scheduler': scheduler.state_dict()}
torch.save(state, 'model_path')

####方法二
torch.save(self.model.state_dict(), 'model_path')  # 当前目录

#######################################################################################


模型加载:
####方法一
self.model.load_state_dict(torch.load("model_path")['model'])

####方法二
self.model.load_state_dict(torch.load('model_path'))

DDP模式下模型保存与加载:

模型保存:
                    无module形式
####方法一
state = {'epoch': epoch,
         'model': model.state_dict(),
         'optimizer': optimizer.state_dict(),
         'scheduler': scheduler.state_dict()}
torch.save(state, 'model_path')

####方法二
torch.save(self.model.state_dict(), 'model_path')


                    module形式---建议模式
####方法一
state = {'epoch': epoch,
         'model': model.module.state_dict(),
         'optimizer': optimizer.state_dict(),
         'scheduler': scheduler.state_dict()}
torch.save(state, 'model_path')

####方法二
torch.save(self.model.module.state_dict(), 'model_path')


#######################################################################################


模型加载:  
        未用含"moduel"方式保存, 导致缺失关键“key”:Missing key(s) in state_dict
############### 方法 1: add
model = torch.nn.DataParallel(model)  # 加上module
model.load_state_dict(torch.load("model_path"))

############### 方法 2: remove
model.load_state_dict({k.replace('module.', ''): v for k, v in                 
                       torch.load("model_path").items()})

############### 方法 3: remove
from collections import OrderedDict
state_dict = torch.load("model_path")
new_state_dict = OrderedDict()   # create new OrderedDict that does not contain `module.`
for k, v in state_dict.items():
    name = k.replace('module.', '')
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

                        含"moduel"方式保存
####方法一
self.model.load_state_dict(torch.load("model_path")['model'])

####方法二
self.model.load_state_dict(torch.load('model_path))

 DDP训练流程:

  • 初始化 

    #初始化使用nccl后端(这个),当然还有别的后端,可以查看官方文档,介绍的比较清晰
    torch.distributed.init_process_group(backend="nccl")
  • 使用DistributedSampler

    DDP并不会自动shard数据如果自己写数据流,得根据torch.distributed.get_rank()去shard数据,获取自己应用的一份如果用Dataset API,则需要在定义Dataloader的时候用DistributedSampler 去shard:

  • 分布式训练

    model=torch.nn.parallel.DistributedDataParallel(model)
  • 参考文章https://zhuanlan.zhihu.com/p/95700549https://zhuanlan.zhihu.com/p/95700549

https://zhuanlan.zhihu.com/p/145427849icon-default.png?t=L892https://zhuanlan.zhihu.com/p/145427849 

DDP训练问题:

1.自定义的模型结构,继承自nn.Module,除__init__(), forward()等重写方法外,在模型结构类内部自定义的一些函数如loss()等,在DDP训练时调用方式需注意

model = YoulModel()
model.loss()  ### 找不到属性loss

#在使用net = torch.nn.DDP(net)之后,原来的net会被封装为新的net的module属性里

model.module.loss() 

<think>好的,用户想知道在PyTorch中使用DDP进行分布式训练时如何正确保存模型参数和状态。首先,我需要回忆一下DDP的工作原理。DDP是分布式数据并行,每个进程都有模型的副本,并且通过Ring-All-Reduce同步梯度。每个进程的参数在训练过程中应该保持一致,对吧? 那保存模型的时候,是否需要每个进程都保存呢?显然不需要,因为这样会导致重复保存,浪费资源。正确的做法应该是只在主进程(rank 0)保存模型。这时候需要考虑如何判断当前进程是否是主进程,可以通过torch.distributed.get_rank()来获取当前进程的排名,如果是0的话就执行保存操作。 接下来,用户可能关心保存哪些内容。模型的状态包括参数和优化器状态,可能还有学习率调度器的状态。使用model.module.state_dict()来获取模型的状态字典,因为DDP会将原始模型包裹在DistributedDataParallel中,所以需要访问module属性才能得到原始模型。 另外,用户可能遇到在加载模型后结果不一致的问题,这可能是因为保存时没有正确使用model.eval(),或者在模型结构中存在随机操作,比如Dropout或BatchNorm。这些在评估模式下需要固定,所以在保存前应该设置模型为eval模式,加载时同样要设置。 还有,用户提到的参考文章里提到保存模型时使用单卡保存,这和在DDP中只在主进程保存是一致的。同时,加载模型时需要注意设备的位置,使用map_location参数来确保模型加载到正确的设备上。 还需要考虑是否在保存前同步所有进程,确保参数一致。虽然DDP在每次迭代后会自动同步梯度,但保存模型时是否会有延迟?可能需要使用torch.distributed.barrier()来同步进程,确保所有进程在保存前参数一致。 总结一下,正确的步骤应该包括:1)仅在主进程保存;2)使用model.module.state_dict();3)保存优化器和调度器状态;4)处理随机层;5)正确加载模型。同时要注意多机多卡的情况,保存路径可能需要是共享的文件系统。 用户可能还会遇到加载模型后预测结果不一致的问题,这时候需要检查是否在加载时使用了相同的配置,比如model.eval(),以及是否有随机种子设置等问题。此外,如果使用多节点训练保存路径的访问权限和共享也需要考虑进去。 最后,应该生成相关问题,帮助用户深入理解DDP模型保存的相关问题,比如如何加载模型、多节点保存的注意事项,以及调试预测不一致的方法。</think>在PyTorch分布式数据并行(DDP训练中,正确保存模型需要特别注意以下关键点: ### 1. 基本保存方法 **仅在主进程保存**:通过判断当前进程是否为rank 0来避免重复保存 ```python if torch.distributed.get_rank() == 0: torch.save({ 'model_state_dict': model.module.state_dict(), # 必须通过model.module获取原始模型 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch }, 'checkpoint.pth') ``` ### 2. 关键技术细节 - **模型参数同步**:DDP通过Ring-All-Reduce架构保持各进程参数一致[^2],但保存前建议添加同步屏障: ```python torch.distributed.barrier() ``` - **随机层处理**:若模型包含Dropout/BatchNorm,保存前需切换评估模式: ```python model.eval() with torch.no_grad(): # 执行验证或保存操作 ``` ### 3. 多节点训练扩展 当跨多台机器训练时: - 确保所有节点能访问存储路径(建议使用共享文件系统) - 保存时指定设备映射: ```python torch.save(..., map_location='cuda:0') ``` ### 4. 完整保存示例 ```python def save_checkpoint(model, optimizer, epoch): if torch.distributed.get_rank() == 0: checkpoint = { 'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch } torch.save(checkpoint, f"model_epoch{epoch}.pth") print(f"Checkpoint saved at epoch {epoch}") ``` ### 5. 模型加载要点 ```python # 加载时需先初始化分布式环境 model = DDP(model) checkpoint = torch.load('checkpoint.pth', map_location='cuda:0') model.module.load_state_dict(checkpoint['model_state_dict']) ```
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值