PyTorch教程:深入理解TorchRec中的分片(Sharding)机制

PyTorch教程:深入理解TorchRec中的分片(Sharding)机制

tutorials PyTorch tutorials. tutorials 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials

引言

在推荐系统和大规模深度学习应用中,嵌入表(Embedding Tables)通常是模型中最消耗内存的部分。随着模型规模的不断扩大,如何在多个设备上高效地分布这些嵌入表成为了一个关键问题。PyTorch的TorchRec库提供了强大的分片(Sharding)功能,可以帮助我们解决这一挑战。

环境准备

在开始之前,我们需要确保环境配置正确:

  1. Python版本:需要Python 3.7或更高版本
  2. CUDA支持:推荐使用CUDA 11.0或更高版本以获得最佳性能

安装必要的软件包:

# 安装PyTorch和CUDA工具包
conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y

# 安装TorchRec
pip3 install torchrec-nightly

# 安装多进程支持库
pip3 install multiprocess

安装完成后,需要重启运行时环境以使更改生效。

理解嵌入表分片

什么是嵌入表分片?

嵌入表分片是指将大型嵌入表分割成多个部分,并将这些部分分布到不同的计算设备上。这种方法可以:

  1. 解决单个设备内存不足的问题
  2. 提高并行计算效率
  3. 实现更好的负载均衡

TorchRec支持的分片策略

TorchRec提供了多种分片策略,每种策略适用于不同的场景:

  1. 表级分片(TABLE_WISE):将整个表放在一个设备上
  2. 行级分片(ROW_WISE):按行维度均匀分割表
  3. 列级分片(COLUMN_WISE):按嵌入维度均匀分割表
  4. 表行级分片(TABLE_ROW_WISE):针对快速设备互连优化的特殊分片
  5. 数据并行(DATA_PARALLEL):在每个设备上复制表

构建嵌入模型

让我们构建一个包含大表和小表的嵌入模型:

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

# 定义大表和小表的配置
large_table_cnt = 2
small_table_cnt = 2

large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=128,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    ) for i in range(large_table_cnt)
]

small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=128,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    ) for i in range(small_table_cnt)
]

# 创建嵌入表集合
ebc = torchrec.EmbeddingBagCollection(
    device="cuda",
    tables=large_tables + small_tables
)

分布式模型并行

TorchRec使用DistributedModelParallel来实现模型并行。下面是一个单进程执行函数的示例,模拟一个GPU rank的工作:

def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> None:
    # 初始化分布式环境
    os.environ["RANK"] = f"{rank}"
    os.environ["WORLD_SIZE"] = f"{world_size}"
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
    
    # 设置设备
    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    
    # 创建分片计划
    topology = Topology(world_size=world_size, compute_device="cuda")
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    
    # 生成分片模型
    sharders = [EmbeddingBagCollectionSharder()]
    plan = planner.collective_plan(module, sharders, pg)
    
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model

分片策略实践

表级分片(TABLE_WISE)

表级分片是最简单的分片方式,每个表完整地存放在一个设备上:

spmd_sharing_simulation(ShardingType.TABLE_WISE)

这种分片方式适合中小型表,能够很好地实现负载均衡。从输出可以看到,大表和小表被均匀地分配到两个GPU上。

行级分片(ROW_WISE)

行级分片将表按行维度分割,适用于单个设备无法容纳的超大表:

spmd_sharing_simulation(ShardingType.ROW_WISE)

从输出可以看到,每个表被均匀地分割成两部分,分别存放在两个GPU上。这种分片方式特别适合行数很多的嵌入表。

列级分片(COLUMN_WISE)

列级分片将表按嵌入维度分割,适用于嵌入维度很大的表:

spmd_sharing_simulation(ShardingType.COLUMN_WISE)

输出显示,每个表的嵌入维度被均匀分割。这种分片方式适合处理嵌入维度很大的表,能够有效解决计算负载不均衡的问题。

分片策略选择指南

选择合适的分片策略需要考虑以下因素:

  1. 表大小

    • 小表:表级分片
    • 大行数:行级分片
    • 大嵌入维度:列级分片
  2. 硬件配置

    • 设备内存大小
    • 设备间互连速度
  3. 计算模式

    • 计算密集型
    • 内存密集型

性能考虑

不同的分片策略会对性能产生不同影响:

  1. 表级分片

    • 通信开销低
    • 内存使用不均衡
  2. 行级分片

    • 内存使用均衡
    • 可能需要更多的通信
  3. 列级分片

    • 计算负载均衡
    • 需要合并操作

结论

TorchRec提供的分片功能为大规模嵌入表的高效分布提供了强大支持。通过合理选择分片策略,我们可以:

  1. 突破单设备内存限制
  2. 提高计算资源利用率
  3. 优化模型训练和推理性能

理解这些分片策略的特点和适用场景,将帮助我们在实际应用中做出更明智的选择,构建更高效的推荐系统和大规模深度学习模型。

tutorials PyTorch tutorials. tutorials 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

尤翔昭Tess

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值