多GPU运行保存加载恢复checkpoint的几个关键

第一部分:认识多GPU的DataParalle model

在这里插入图片描述
在pytorch中如果使用单机器多块GPU时,会有一些小的注意事项,似乎大部分人都找不到合适的完整的介绍,这里把之前总结的做一个汇总,希望能帮更多人建立完整的知识框架。

第1层:认识model本身

类型1:
如果是cpu model或单GPU model,2种形式(sequential model和sequential&OrderedDict model)

Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU())

一种是sequential model如上,引用具体层的方式是model[0],因为sequential会自动对layers编号,可类似于list切片方式调用

Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU())

另一种是sequential&OrderedDict model,引用引用方式是model.conv1,因为sequential针对ordereddict进行了优化可以直接通过.来调用层

类型2:
如果是data paralle model:可以看到是在原model基础上wrap了一个module外壳如下

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    …))

引用方式类似OrderedDict model的嵌套:
model.module就是引用内部的真实模型
model.module.conv1就是引用具体层(前提是sequential内包含了OrderedDict)
model.module[0]也是引用具体层(前提是sequential内不包含OrderedDict)

第2层:认识checkpoint文件

通过torch.save(name, dir)保存的就叫checkpoint文件,可以存一个dict或存一个state_dict (OrderedDict)。一般的dict文件用来保存模型运行的状态信息和参数,state_dict用来保存参数是dict的一部分。如下是一个典型的checkpoint数据结构
{meta: dict
state_dict: OrderedDict
optimizer: dict}

第3层:认识state_dict字典

state_dict必然是一个OrderedDict数据类型,保存的内容就是所有深度学习需要优化的参数。

如果是Data paralle model在模型的state_dict中保存的内容会额外增加以module作为开头

odict_keys(['module.features.0.weight', 
		      'module.features.0.bias', 
		      'module.features.3.weight', 
		      'module.features.3.bias', 
		      'module.classifier.0.weight', 
	             'module.classifier.0.bias', 
		      'module.classifier.2.weight', 
		      'module.classifier.2.bias', 
		      'module.classifier.4.weight', 
		      'module.classifier.4.bias'])

第二部分:处理模型的保存和加载的流程

1. save:

通过torch.save(name, dir)完成,先组合checkpoint, 然后保存
参考:https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

方案1:直接save state_dict(OrderedDict),且只save state_dict,后边只是用于inference。这样就会有2种形式state_dict存在,一种不带module前缀,一种带module前缀。
注:state_dict需要先.cpu()

方案2:直接save state_dict(OrderedDict),且只save state_dict,后边只是后边用于inference。但state_dict格式都统一成不带module的形式。
注:torch.save(model.state_dict, filepath) 此时保存的是OrderedDict,但可能有两种形式
注:torch.save(model.module.state_dict, filepath)此时保存的也是state_dict,但data paralle不会再带有module前缀

方案3:间接save整个training status(dict),不只save state_dict,还save epoch/iter/optimizer state_dict等状态参数,后边用于回复训练,这个dict需要自己组合
方案3是最常用的方案,因为他适用范围更广,不但可以training也可以inference

2. load checkpoint:

通过torch.load(checkpoint, map_location)完成

通过map_location控制加载位置:
可以加载到cpu: map_location = lambda storage, loc: storage
可以加载到GPU: map_location = lambda storage, loc: storage.cuda(0)

3. 获得state_dict:

(1)获得checkpoint后需要对checkpoint判断后处理获得state_dict:
如果checkpoint是OrderedDict,那么可以直接得到state_dict
如果checkpoint是dict,那么可以从字典中获得state_dict
如果state_dict包含module前缀,那么需要先去除module前缀,下面是一种处理前缀方式

if list(state_dict.keys())[0].startswith('module.'):
    state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

核心原因:saving DataParallel wrapped model can cause problems when the model_state_dict is loaded into an unwrapped model. 保存了wrapped model然后加载到unwrapped model所以必然出错。
参考:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4

4. load state dict:

通过load_state_dict(model, state_dict)完成
(1)模型加载state_dict前需要对model判断处理:
如果是不带module的model,则加载不带module的state_dict:load_state_dict(model, state_dict)
如果是带module的model,则取内层model(相当于去掉module)然后加载不带module的state_dict,如下:load_state_dict(model.module, state_dict)

关键:带module wrap的model,可以直接加载带module前缀的state_dict,此时model和state_dict都不需要做去module化处理,model.load_state_dict(state_dict_w/_module)即可;当然也可以model和state_dict同时做去module化处理,model.module.load_state_dict(state_dict_w/o_module)

<think>好的,用户的问题是关于如何在PyTorch保存训练过程中的checkpoint以及如何加载指定轮次的.pt模型文件。我需要先回顾一下PyTorch的相关知识,确保自己记得正确的做法。 首先,保存checkpoint通常涉及在训练循环的每个epoch结束后,将模型的state_dict、优化器的state_dict以及当前的epoch和损失等信息保存到一个字典中,然后使用torch.save()保存为.pt或.pth文件。这部分我记得没错,但可能需要详细说明步骤,比如如何组织字典里的内容。 接下来是加载checkpoint的部分。用户需要加载指定轮次的模型,所以应该指导他们如何选择文件,然后使用torch.load()加载字典,再分别将模型和优化器的状态加载回去。需要注意的是,如果要在训练中断后继续训练,必须恢复优化器的状态和当前的epoch数;如果只是用于推理,可能不需要优化器的部分。 用户可能遇到的问题包括文件路径错误、版本不兼容导致加载失败,或者加载后模型表现不一致。这时候需要提醒用户检查文件是否存在,确保模型结构一致,以及是否在加载前调用了model.eval()模式。 另外,用户可能想了解如何自动保存checkpoint,比如每个epoch保存一个,或者只保留最好的几个。这时候可能需要提到在训练循环中加入条件判断,或者使用最大数量的检查点,覆盖旧的保存文件。 还要考虑代码示例的准确性,确保示例中的关键步骤都有注释,比如模型和优化器的定义部分可能需要用户根据实际情况替换。同时,保存的文件名应该包含epoch数或其他标识符,方便后续查。 用户可能没有提到的需求包括如何在加载时处理不同的设备(比如从GPU保存的模型加载到CPU),这时候可以在注意事项里提到map_location参数的使用。此外,保存其他信息如学习率调度器的状态可能也是有用的,可以适当补充。 最后,生成的相关问题需要围绕checkpoint保存加载展开,可能涉及最佳实践、恢复训练、模型评估等。要确保这些问题对用户有进一步的帮助,引导他们深入理解相关知识点。</think>在PyTorch保存加载训练检查点(checkpoint)是模型训练的重要实践,以下是具体实现方法: ### 一、保存Checkpoint的步骤 1. **定义保存路径** ```python checkpoint_dir = "./checkpoints/" os.makedirs(checkpoint_dir, exist_ok=True) ``` 2. **在训练循环中保存** ```python for epoch in range(num_epochs): # 训练过程... # 创建检查点字典 checkpoint = { 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': current_loss, 'accuracy': current_acc } # 保存检查点(每个epoch保存) torch.save(checkpoint, f"{checkpoint_dir}epoch_{epoch+1}_checkpoint.pt") # 可选:只保留最佳模型 if current_acc > best_acc: best_acc = current_acc torch.save(checkpoint, f"{checkpoint_dir}best_checkpoint.pt") ``` ### 二、加载Checkpoint的方法 1. **加载指定轮次模型** ```python def load_specific_checkpoint(epoch_num): checkpoint_path = f"./checkpoints/epoch_{epoch_num}_checkpoint.pt" checkpoint = torch.load(checkpoint_path) model = YourModelClass() # 需要先实例化模型 model.load_state_dict(checkpoint['model_state_dict']) optimizer = optim.Adam(model.parameters()) # 需要匹配优化器类型 optimizer.load_state_dict(checkpoint['optimizer_state_dict']) return model, optimizer, checkpoint['epoch'], checkpoint['loss'] ``` 2. **继续训练使用示例** ```python # 加载第10轮的检查点 model, optimizer, start_epoch, _ = load_specific_checkpoint(10) # 继续训练 for epoch in range(start_epoch, num_epochs): # 训练循环... ``` ### 三、关键注意事项 1. **文件管理**:建议文件名包含时间戳和epoch编号,如`resnet50_20240520_epoch50.pt` 2. **设备兼容性**:使用`torch.save(model.module.state_dict())`处理GPU训练[^1] 3. **版本控制**:保存训练代码版本,防止模型与代码版本不匹配 4. **最小信息原则**:推荐保存完整模型时使用`torch.save(model, 'full_model.pth')`,但检查点建议分开保存参数 ### 四、推荐目录结构 ``` ├── checkpoints/ │ ├── epoch_1_checkpoint.pt │ ├── epoch_10_checkpoint.pt │ └── best_checkpoint.pt ├── utils/ │ └── checkpoint_manager.py └── train.py ``` ### 五、最佳实践 1. 使用`torch.save()`的`_use_new_zipfile_serialization`参数确保兼容性(PyTorch 1.6+默认启用) 2. 加载时使用`map_location`参数处理设备差异: ```python checkpoint = torch.load('model.pth', map_location=torch.device('cpu')) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值