PyTorch教程:深入理解TorchRec中的分片(Sharding)机制
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
引言
在推荐系统和大规模深度学习应用中,嵌入表(Embedding Tables)通常是模型中最消耗内存的部分。随着模型规模的不断扩大,如何在多个设备上高效地分布这些嵌入表成为了一个关键问题。PyTorch的TorchRec库提供了强大的分片(Sharding)功能,可以帮助我们解决这一挑战。
环境准备
在开始之前,我们需要确保环境配置正确:
- Python版本:需要Python 3.7或更高版本
- CUDA支持:推荐使用CUDA 11.0或更高版本以获得最佳性能
安装必要的软件包:
# 安装PyTorch和CUDA工具包
conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
# 安装TorchRec
pip3 install torchrec-nightly
# 安装多进程支持库
pip3 install multiprocess
安装完成后,需要重启运行时环境以使更改生效。
理解嵌入表分片
什么是嵌入表分片?
嵌入表分片是指将大型嵌入表分割成多个部分,并将这些部分分布到不同的计算设备上。这种方法可以:
- 解决单个设备内存不足的问题
- 提高并行计算效率
- 实现更好的负载均衡
TorchRec支持的分片策略
TorchRec提供了多种分片策略,每种策略适用于不同的场景:
- 表级分片(TABLE_WISE):将整个表放在一个设备上
- 行级分片(ROW_WISE):按行维度均匀分割表
- 列级分片(COLUMN_WISE):按嵌入维度均匀分割表
- 表行级分片(TABLE_ROW_WISE):针对快速设备互连优化的特殊分片
- 数据并行(DATA_PARALLEL):在每个设备上复制表
构建嵌入模型
让我们构建一个包含大表和小表的嵌入模型:
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict
# 定义大表和小表的配置
large_table_cnt = 2
small_table_cnt = 2
large_tables = [
torchrec.EmbeddingBagConfig(
name="large_table_" + str(i),
embedding_dim=128,
num_embeddings=4096,
feature_names=["large_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(large_table_cnt)
]
small_tables = [
torchrec.EmbeddingBagConfig(
name="small_table_" + str(i),
embedding_dim=128,
num_embeddings=1024,
feature_names=["small_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(small_table_cnt)
]
# 创建嵌入表集合
ebc = torchrec.EmbeddingBagCollection(
device="cuda",
tables=large_tables + small_tables
)
分布式模型并行
TorchRec使用DistributedModelParallel
来实现模型并行。下面是一个单进程执行函数的示例,模拟一个GPU rank的工作:
def single_rank_execution(
rank: int,
world_size: int,
constraints: Dict[str, ParameterConstraints],
module: torch.nn.Module,
backend: str,
) -> None:
# 初始化分布式环境
os.environ["RANK"] = f"{rank}"
os.environ["WORLD_SIZE"] = f"{world_size}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
# 设置设备
if backend == "nccl":
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
# 创建分片计划
topology = Topology(world_size=world_size, compute_device="cuda")
planner = EmbeddingShardingPlanner(
topology=topology,
constraints=constraints,
)
# 生成分片模型
sharders = [EmbeddingBagCollectionSharder()]
plan = planner.collective_plan(module, sharders, pg)
sharded_model = DistributedModelParallel(
module,
env=ShardingEnv.from_process_group(pg),
plan=plan,
sharders=sharders,
device=device,
)
print(f"rank:{rank},sharding plan: {plan}")
return sharded_model
分片策略实践
表级分片(TABLE_WISE)
表级分片是最简单的分片方式,每个表完整地存放在一个设备上:
spmd_sharing_simulation(ShardingType.TABLE_WISE)
这种分片方式适合中小型表,能够很好地实现负载均衡。从输出可以看到,大表和小表被均匀地分配到两个GPU上。
行级分片(ROW_WISE)
行级分片将表按行维度分割,适用于单个设备无法容纳的超大表:
spmd_sharing_simulation(ShardingType.ROW_WISE)
从输出可以看到,每个表被均匀地分割成两部分,分别存放在两个GPU上。这种分片方式特别适合行数很多的嵌入表。
列级分片(COLUMN_WISE)
列级分片将表按嵌入维度分割,适用于嵌入维度很大的表:
spmd_sharing_simulation(ShardingType.COLUMN_WISE)
输出显示,每个表的嵌入维度被均匀分割。这种分片方式适合处理嵌入维度很大的表,能够有效解决计算负载不均衡的问题。
分片策略选择指南
选择合适的分片策略需要考虑以下因素:
-
表大小:
- 小表:表级分片
- 大行数:行级分片
- 大嵌入维度:列级分片
-
硬件配置:
- 设备内存大小
- 设备间互连速度
-
计算模式:
- 计算密集型
- 内存密集型
性能考虑
不同的分片策略会对性能产生不同影响:
-
表级分片:
- 通信开销低
- 内存使用不均衡
-
行级分片:
- 内存使用均衡
- 可能需要更多的通信
-
列级分片:
- 计算负载均衡
- 需要合并操作
结论
TorchRec提供的分片功能为大规模嵌入表的高效分布提供了强大支持。通过合理选择分片策略,我们可以:
- 突破单设备内存限制
- 提高计算资源利用率
- 优化模型训练和推理性能
理解这些分片策略的特点和适用场景,将帮助我们在实际应用中做出更明智的选择,构建更高效的推荐系统和大规模深度学习模型。
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考