PyTorch分布式检查点(DCP)使用指南
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
概述
在分布式训练环境中,模型参数和梯度被划分到多个训练器上,这使得模型检查点的保存和恢复变得复杂。PyTorch分布式检查点(Distributed Checkpoint, DCP)提供了一套解决方案,可以简化这一过程。本文将详细介绍DCP的工作原理、使用方法以及与常规检查点方法的区别。
DCP核心概念
与传统检查点的区别
DCP与传统的torch.save
和torch.load
有以下主要区别:
- 多文件存储:每个检查点会生成多个文件,每个rank至少一个
- 原地操作:模型需要先分配存储空间,DCP直接使用这些存储
- 状态对象处理:自动调用
state_dict
和load_state_dict
方法
关键优势
- 并行保存/加载:支持从多个rank并行保存和加载模型
- 拓扑灵活性:可以在不同集群拓扑结构间重新分片
- 状态字典管理:自动处理跨模型和优化器的全限定名(FQN)映射
实践指南
准备工作
首先需要设置分布式环境:
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
模型定义
我们使用一个简单的FSDP包装模型作为示例:
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
状态管理封装
AppState
类封装了模型和优化器的状态管理:
class AppState(Stateful):
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
保存检查点
保存FSDP包装模型的完整流程:
def run_fsdp_checkpoint_save_example(rank, world_size):
setup(rank, world_size)
model = ToyModel().to(rank)
model = FSDP(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
# 训练步骤...
state_dict = {"app": AppState(model, optimizer)}
dcp.save(state_dict, checkpoint_id="checkpoint_dir")
cleanup()
加载检查点
从检查点恢复模型的流程:
def run_fsdp_checkpoint_load_example(rank, world_size):
setup(rank, world_size)
model = ToyModel().to(rank)
model = FSDP(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
state_dict = {"app": AppState(model, optimizer)}
dcp.load(state_dict=state_dict, checkpoint_id="checkpoint_dir")
cleanup()
非分布式环境加载
DCP也支持在非分布式环境中加载检查点:
def run_checkpoint_load_example():
model = ToyModel()
state_dict = {"model": model.state_dict()}
dcp.load(state_dict=state_dict, checkpoint_id="checkpoint_dir")
model.load_state_dict(state_dict["model"])
格式转换工具
DCP提供了与torch.save
格式相互转换的工具:
命令行工具
python -m torch.distributed.checkpoint.format_utils <mode> <input> <output>
编程方式
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
# DCP转torch.save格式
dcp_to_torch_save("dcp_checkpoint", "torch_save.pth")
# torch.save转DCP格式
torch_save_to_dcp("torch_save.pth", "new_dcp_checkpoint")
最佳实践
- 状态管理:使用
AppState
等封装类简化状态管理 - 环境隔离:确保保存和加载时的环境配置一致
- 格式选择:根据使用场景选择合适的存储格式
- 错误处理:添加适当的错误处理和日志记录
总结
PyTorch分布式检查点(DCP)为分布式训练环境提供了强大的模型保存和恢复能力。通过本文的介绍,您应该已经掌握了:
- DCP的核心概念和优势
- 如何在分布式环境中保存和加载模型
- 如何在非分布式环境中使用DCP检查点
- 不同存储格式间的转换方法
DCP特别适合大规模分布式训练场景,能够有效解决参数分片和集群拓扑变化带来的挑战。
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考