训练性能显著提升,字节跳动郑思泽详解 Triton-distributed 框架,实现大模型高效分布式通信与计算融合

2025 年,由 HyperAI 超神经主办的 Meet AI Complier 技术沙龙已经行至第 7 期,在社区小伙伴和多位行业专家的支持下,我们在北京、上海、深圳等地建立了多个据点,为开发者和爱好者提供交流平台,揭开创新技术的神秘面纱,直面一线开发者的应用反馈,共享技术落地的实战经验,聆听多角度的创新思维。

关注微信公众号「HyperAI 超神经」,后台回复关键字「0705 AI 编译器」,即可获取确认授权的讲师演讲 PPT 。

在「Triton-distributed:原生 Python 编程实现高性能通信」主题演讲中,来自字节跳动的 Seed Research Scientist 郑思泽详细解析了 Triton-distributed 在大模型训练中的通信效率突破、跨平台适配能力,以及如何通过 Python 编程实现通信与计算的深度融合。分享结束后,现场迅速进入提问高峰,围绕 FLUX 框架、 Tile 编程模型、 AllGather 与 ReduceScatter 优化等细节展开的探讨层出不穷,讨论聚焦核心技术难点与实践经验,切实促进了理论与应用的结合。

HyperAI 超神经在不违原意的前提下,对郑思泽老师的演讲分享进行了整理汇总,以下为演讲实录。

分布式训练的现实挑战

在当前大模型迅速演进的背景下,无论是训练还是推理,分布式系统已成为不可或缺的一环。我们也在这一方向上开展了编译器层面的探索,并已开源该项目,命名为 Triton-Distributed 。

当前主流的硬件互联方式包括 NVLink 、 PCIe 以及跨节点的网络通信。在理想条件下,H100 的 NVLink 单向带宽可以达到 450GB/s,但在国内大多数部署中,更常见的实际是 H800,其单向带宽只有约 200GB/s,整体通信能力和拓扑复杂度大幅下降。我们在项目中遇到的一个明显挑战便是由于带宽不足与通信拓扑不对称带来的系统性能瓶颈。

不同 GPU 集群的带宽

针对此,早期的分布式优化往往依赖大量手动实现的通信算子,包括 Tensor 并行、 Pipeline 并行、 Data 并行等策略,均需要精心编写底层通信逻辑。常见做法是调用如 NCCL 、 ROCm CCL 等通信库,但这类方案往往缺乏通用性和可移植性,开发成本和维护代价都较高。

在分析现有系统瓶颈时,我们总结了 3 个关键事实:

Fact 1:硬件带宽受限,通信延迟成为瓶颈

首先是基础硬件条件带来的限制。如果用 H100 训练大模型的话,计算延迟往往显著高于通信延迟,因此无需特别关注计算与通信的重叠调度。但在当前 H800 的环境下,通信延迟被明显拉长。我们评估过,在某些场景下,近一半的训练时间会被通信延迟空消耗掉,导致整体 MSU(Model Scale Utilization)显著下降。若不进行通信与计算的 overlap(重叠)优化,系统将面临严重的资源浪费问题。

硬件带宽延迟

在中小规模下,这种损耗尚可接受;但一旦模型扩展到数千张卡级别,比如在 MegaScale 或 DeepSeek 的训练实践中,那么累计的资源损失将达到百万甚至千万美元级别,这对企业而言是非常现实的成本压力。

推理场景同样如此。 DeepSeek 早期推理部署使用了多达 320 张卡,尽管后续进行了压缩和优化,但通信延迟依旧是分布式系统不可回避的核心问题。因此,如何在程序层面有效调度通信与计算、提升整体效率,成为我们必须正面应对的关键课题。

Fact 2: 通信开销高,直接影响 MFU 表现

在当前的大模型训练和推理中,通信开销始终是一个主要瓶颈。我们观察到,无论底层使用的是 NVLink 、 PCIe,还是不同代际的 GPU(比如 A100 、 H800),通信所占比例都非常高。特别是在国内的实际部署中,由于带宽限制更明显,通信延迟会直接拖慢整体效率。

对于大模型训练来说,这种高频的跨卡通信会显著拉低系统的 MFU 。因此,优化通信开销对于提升训练与推理性能,都是非常关键的提升点,也是我们重点关注的方向之一。

不同 GPU 配置的通信与计算占比

Fact 3: 可编程性与性能之间的鸿沟

目前在分布式系统中,可编程性与性能之间仍存在较大鸿沟。过去我们更关注单卡编译器的优化能力,比如如何在一张卡上发挥出色的性能;但当我们扩展到单机多卡,甚至跨节点的分布式系统后,情况就更加复杂。

一方面,分布式通信涉及大量底层技术细节,如 NCCL 、 MPI 、拓扑结构,且分散在各种专用库中,使用门槛较高。很多时候,开发者需要手动实现通信逻辑、手动调度计算与同步,开发成本和出错概率都很高。另一方面,如果有工具能自动处理分布式下复杂的通信调度和算子优化,就可以帮助开发者显著降低开发门槛,提升分布式系统的可用性和可维护性,这正是我们在 Triton-Distributed 中希望解决的问题之一。

不同通信编程方式的特点

基于前面提到的 3 个现实问题,我们在 Triton-Distributed 中提出了 3 个核心方向:

首先,推动通信与计算的重叠(overlapping)机制。在通信开销日益突出的分布式场景下,我们希望尽可能调度出计算与通信的并行窗口,提升系统整体效率。

其次,需要针对大模型的计算与通信模式进行深度融合与适配。比如在模型中常见的 AllReduce 、 Broadcast 等通信 pattern,我们尝试将其与计算的 pattern 做融合,从而减少同步等待、压缩执行路径。

最后,我们认为,这些优化应通过编译器完成,而不是依赖开发者手动编写高度定制化的 CUDA 实现。让分布式系统的开发更抽象、更高效,是我们努力的方向。

Triton-distributed 架构解析:原生 Python 实现高性能通信

我们希望在分布式训练中实现 overlapping,但真正落地并不容易。概念上,overlapping 是指通过多个 stream 并发执行计算和通信,以掩盖通信延迟。这在算子之间无依赖的场景下较为容易,但像 Tensor Parallel(TP)或 Expert Parallel(EP)中,必须先完成 AllGather 才能进行 GEMM,二者处于关键路径,重叠难度很大。

目前常见方法包括:一是将任务划分为多个 Micro-Batch,借助 Batch 间的独立性实现 overlap;二是在单个 Batch 内以更细粒度(如 tile 粒度)进行切分,通过 kernel fusion 达到并行效果。我们在 Flux 中也探索了这类切分与调度机制,同时,大模型训练中的通信模式高度复杂。例如 DeepSeek 在做 MoE 时需自定义 All-to-All 通信,以兼顾带宽和负载均衡;又如低延迟推理和量化场景下,NCCL 等通用库难以满足性能要求,往往需要手写通信 kernel,这些都提高了定制化成本。

因此,我们认为通信-计算融合的优化能力,应由编译器层承担,以应对复杂模型结构和多样硬件环境,避免重复手工实现带来的开发负担。

两层通信原语抽象

在我们的编译器设计中,采用了两层通信原语(primitives)抽象结构,以兼顾上层优化表达能力和底层部署的可落地性。

第一层是偏高层的原语,主要在 tile 粒度上完成计算调度,并提供面向通信的抽象接口。它以 rank 间的 push/get 操作作为通信抽象,并通过 tag 标识机制区分每一次通信行为,方便调度器追踪数据流与依赖关系。

第二层则更贴近底层实现,采用了一套类似于 Open Shared Memory 标准(OpenSHMEM)的原语体系。这一层主要用于映射到现有的通信库或硬件后端,实现真实的通信行为。

此外,在多 rank 的场景中,我们还需要引入 barrier 与 signal 控制机制,用于跨 rank 的同步。比如需要通知其他 rank 我方的数据已写入完毕,或等待某个 rank 的数据准备就绪时,这类同步信号就非常关键。

编译流程图

编译器架构与语义建模

在编译栈方面,我们的整体流程仍然基于原始的 Triton 编译框架。从源码开始,Triton 会先将用户代码转为抽象语法树(AST),再翻译成 Triton IR 。而在我们构建的 Triton-Distributed 中,我们对原有的 Triton IR 做了扩展,新增了一套面向分布式语义的 IR 层。这套分布式 IR 中,引入了对同步操作的语义建模,例如 wait 和 notify,用于描述 rank 之间的通信依赖关系;同时,我们还设计了一套面向 OpenSHMEM 的语义接口,以支持更底层的通信调用。

在实际代码生成阶段,这些语义可以映射为对底层通信库的外部调用(external call)。我们通过 LLVM 中间层,直接将这些调用链接到 OpenSHMEM 提供的 bitcode 版本的库(而非源码),以实现跨 rank 的高效共享内存通信。这种方式绕过了 Triton 不支持源码直接接入 external lib 的限制,使得共享内存相关的调用可以在编译期顺利完成符号解析与链接。

编译栈图

高层原语与底层执行的映射机制

在 Triton-distributed 中,我们设计了一套覆盖高层抽象与底层控制的通信原语体系。以 consumer_tile_wait 为例,开发者只需声明需要等待的 tile ID,系统会自动根据当前算子语义(如 AllGather)推导出通信目标的具体 rank 和 offset,完成同步逻辑。高层原语屏蔽了具体数据来源与信号传递的细节,提升了开发效率。

相比之下,底层原语则提供了更细粒度的控制能力。开发者需要手动指定 signal 指针、作用域(GPU 或 system)、内存语义(acquire 、 release 等)及预期值。这种机制虽然更复杂,但适用于对通信延迟和调度精准性要求极高的场景。

高层原语与底层原语

高层次的原语大致分为两类:信号控制和数据控制。在信号控制的语义中,我们主要定义了 3 类角色:producer 、 consumer 和 peer,它们之间通过读写 signal 实现同步,类似于分布式通信中的握手机制。对于数据传输,Triton-distributed 提供了 push 与 pull 两种原语,分别对应主动将数据发送到远端卡,或从远端拉取数据到本地卡。

所有底层通信原语均遵循 OpenSHMEM 标准,当前已支持 NVSHMEM 和 ROCSHMEM 。高层与底层原语之间具备明确的映射关系,编译器负责将简洁的接口自动转换为底层的同步与传输指令。通过这套机制,Triton-distributed 既保留了通信调度的高性能能力,也大幅降低了分布式编程的复杂度。

在 Triton-distributed 中,高层通信原语(如 notify 和 wait)的设计目标是以简洁语义描述跨卡同步需求,同时由编译器负责将其翻译为对应的底层执行逻辑。以 notify 为例,它与 wait 构成同步语义的一对:前者用于发送通知,后者用于等待数据准备完成。开发者只需指定 tile ID,系统即可根据算子类型与通信拓扑,自动推导出通信目标、信号偏移等底层细节。

具体的底层实现会因部署环境而异。例如在 8 卡 GPU 的场景中,这类同步可通过线程内的 _syncthreads() 与 atomic_dd 实现;在跨机部署中,则依赖于如 NVSHMEM 或 ROCSHMEM 提供的 signal_up 等原语完成等效操作。这些机制共同构成了高层语义与底层原语之间的映射关系,具有良好的通用性和可扩展性。

高层语义与底层原语之间的映射关系

以一个 GEMM ReduceScatter 的通信场景为例:假设系统中有 4 张 GPU,每个 tile 的目标位置由预先计算的元信息(如每个 rank 的 tile 分配量、 barrier 数量)决定。开发者只需在 Triton 编写的 GEMM kernel 中添加一条 notify 语句,而 ReduceScatter kernel 端则用 wait 来同步接收数据。

整个过程可在 Python 中表达,也支持双 stream 启动的 kernel 模式,通信逻辑清晰且易于调度。这一机制不仅提高了跨卡通信编程的可表达性,也大幅降低了底层实现的复杂度,为分布式大模型的高效训练与推理提供了强有力的基础能力支持。

多维度的 Overlapping 优化:从调度机制到拓扑感知

虽然 Triton-distributed 已经提供了相对简洁的高层通信原语接口,但在实际编写和优化 kernel 的过程中,仍存在一定的技术门槛。我们观察到,尽管原语设计具备良好的表达能力,但真正能够灵活运用并深入优化的用户仍然有限。本质上,通信优化仍是一项强依赖工程经验和调度理解的工作,目前仍需由开发者手动控制。围绕这个问题,我们总结出一些关键优化路径,以下为 Triton-distributed 中的典型实现策略。

Push vs Pull:数据流向与 barrier 数控制约

在通信与计算的重叠优化中,Triton-distributed 提供了 push 和 pull 两种数据传输方式。虽然它们在语义上仅仅是「主动发送」与「被动拉取」的方向差异,但在实际的分布式执行中,其性能表现和调度控制能力却存在明显不同。

以 barrier 数量为例,pull 模式通常需要设置 2 个 barrier:一个用于确保本地数据在被对方拉取前已经准备好,另一个则用于保护该数据在整个通信周期内不会被本地任务修改,从而防止数据不一致或读写冲突。而 push 模式则只需要在数据写入远端后设置一个 barrier,用以同步所有设备即可,整体控制更简单。

但 pull 模式也有其优势,它允许本地节点主动控制数据拉取顺序,从而更精确地调度通信时机与计算重叠关系。当我们希望最大化 overlap 效果、实现通信与计算的并行性时,pull 提供了更高的灵活性。

总体来看,如果主要目标是提升 overlap,则推荐使用 pull;而在一些纯通信任务中,如单独的 AllGather 或 ReduceScatter kernel,push 模式因其实现简洁、开销更小而更为常见。

Push 模式与 Pull 模式流程图

Swizzling 调度:按数据局部性动态调整顺序

通信与计算的重叠不仅依赖于原语选择,还与调度策略密切相关。其中,Swizzling 是一种基于拓扑感知的调度优化手段,旨在减少跨卡计算过程中的执行空闲。在分布式视角下,可以将每张 GPU 卡视为一个独立的执行单元。由于每张卡初始持有的数据片段不同,若所有卡从相同 tile 索引开始计算,部分 rank 将不得不等待数据就绪,导致执行阶段出现长时间的空闲,从而拉低整体计算效率。

Swizzling 的核心思想是:根据每张卡本地已有数据的位置动态调整起始计算偏移。例如,在 AllGather 场景中,每张卡可以优先处理自身持有的数据,同时发起对远端 tile 的拉取,实现通信与计算的并发调度。若所有卡一律从 tile 0 开始处理,只有 rank 0 能立即开始计算,其余 rank 则将因等待数据而产生串行延迟。

更复杂的情形如跨机 ReduceScatter 场景中,Swizzling 策略还需结合网络拓扑进行设计。以两台节点(node)为例,合理的调度方式是:优先计算对方节点所需的数据,尽早触发跨机 point-to-point 通信;而在传输过程中,再并行计算本地节点所需数据,最大化通信与计算的 overlap 效果。

目前,这类调度优化仍由编程者控制,以避免编译器在通用优化中牺牲关键性能路径。我们也意识到,理解 Swizzling 等细节对开发者有一定门槛。未来,我们希望通过提供更多实际案例和模板代码,帮助开发者更快掌握分布式算子开发模式,逐步构建起开放、高效的 Triton-distributed 编程生态。

Swizzling 优化流程

非完美分块调度:跨 rank tile 的处理优先级

在实际的大模型训练与推理场景中,算子的输入 shape 往往并不规整,尤其是在 token 长度不固定的情况下,tile 分块也难以保持整齐划一。这种非完美分块(Imperfect Tiling)会导致部分 tile 横跨多个 rank,即同一个 tile 的数据分布在多个设备上,增加了调度与同步的复杂性。

以 AllGather GEMM 为例,假设某个 tile 同时包含了本地和远端的数据。如果从这个 tile 开始计算,则必须等待远端数据先完成传输,进而引入额外的 bubble,影响整体计算的并行性。更优的做法是:跳过这个跨 rank tile,优先处理完全本地可用的数据,将等待远端输入的 tile 调度至最后执行,从而实现通信与计算的最大重叠。

而在 ReduceScatter 场景中,调度顺序则应反向处理。由于跨 rank tile 的计算结果需要尽早发送给远端,最佳策略是:优先处理那些被远端节点依赖的 tile,以便尽早完成跨机数据发送,减少远端的依赖。

非完美分块调度

MoE 下的 Dynamic Sorting 策略

在 MoE(Mixture-of-Experts)模型中,token 需要根据路由结果被分发至多个 expert,通常伴随 All-to-All 通信与 Group GEMM 计算。为了提升通信与计算的重叠效率,Triton-distributed 引入了 Dynamic Sorting,按计算任务对通信数据的依赖强度进行分阶段调度,优先处理数据依赖较少的部分。

这种排序方式确保了每一阶段的计算都能以尽可能低的通信阻塞开始,从而在 All-to-All 与 Group GEMM 之间实现更好的 overlap 效果。整体调度从数据依赖最少的 tile 开始,逐步扩展至依赖复杂的数据块,最大程度提升了执行并发性。

Dynamic Sorting

面向硬件的通信加速

Triton-distributed 还支持结合特定硬件能力进行通信优化,尤其是在使用 NVSwitch 架构时,可利用其内置的 SHARP Accelerator 执行低延迟的通信计算。该模块可在交换芯片内完成如 Broadcast 、 AllReduce 等操作,实现数据在传输路径中的聚合加速,减少延迟与带宽消耗。相关指令已集成进 Triton-distributed,具备相应硬件的用户可直接调用,构建更高效的通信 kernel 。

AOT 编译优化:降低推理延迟开销

Triton-distributed 引入了 AOT(Ahead-of-Time,提前编译)机制,专门针对推理场景中对延迟极度敏感的需求进行优化。 Triton 默认采用 JIT(Just-In-Time compilation,即时编译)编译方式,函数首次执行时存在显著的编译与缓存开销。

AOT 机制则允许用户在运行前将函数预编译为字节码,推理阶段直接加载执行,避免了 JIT 编译过程,从而有效降低了编译及缓存带来的延迟。基于此,Triton-distributed 对 AOT 机制进行了扩展,现已支持分布式环境中的 AOT 编译与部署,进一步提升了分布式推理的性能表现。

性能实测与案例复现

我们对 Triton-distributed 在多平台、多任务场景下的性能进行了全面测试,涵盖 NVIDIA H800 、 AMD GPU 、 8 卡 GPU 与跨机集群,并对比了 PyTorch 、 Flux 等主流分布式实现方案。

在 8 卡 GPU 上,Triton-distributed 在 AG GEMM 和 GEMM RS 任务中相较 PyTorch 实现有显著加速,相比手工优化的 Flux 方案也取得更优性能,得益于 Swizzling 调度、通信 offload 和 AOT 编译等多重优化。同时在 AMD 平台上对比 PyTorch + RCCL 的组合,虽然整体加速幅度略小,但同样取得显著优化,限制主要来自测试硬件算力偏弱和非 switch 拓扑。

在 AllReduce 任务中,Triton-distributed 在从小到大的多种消息尺寸下,在我们测试的硬件配置中,相比 NCCL 均有明显加速,平均加速约 1.6 倍。在 Attention 场景中,我们主要测试了 gather-KV 类型的 attention 操作。相较于 PyTorch Touch 的原生实现,Triton-distributed 在 8 卡 GPU 上的性能可达约 5 倍提升;同时也优于开源的 Ring Attention 实现,提升幅度约为 2 倍。

跨机测试方面,AG GEMM 提速 1.3 倍,GEMM RS 提速 1.4 倍,表现略低于 Flux,但在 shape 灵活性和可扩展性上更具优势。我们还测试了高速推理场景下的单 token decoding,在 1M token context 下延迟可控制在 20–30 微秒,兼容 NVLink 与 PCIe 。

此外,我们对 DeepEP 中的分布式调度逻辑进行了功能复现,主要对齐其 All-to-All 路由与上下文分发策略。在 64 卡以内的场景下,Triton-distributed 的性能与其基本持平,部分配置下略有提升。

最后,我们还提供了基于 Qwen-32B 的 prefill 与 decode Demo,支持在 8 卡 GPU 上部署运行,实测可获得约 1.2 倍的推理加速效果。

打造开放的分布式编译生态

目前我们正面临定制化 overlapping 场景的挑战,过去主要依赖手工优化解决,工作量大且成本高。为此,我们提出并开源了分布式的 Triton-distributed 框架。虽然它是基于 Triton 实现的,但其实不论各家公司使用何种编译器,或者底层通信库如何,都能将其集成进来,打造一个开放的分布式生态。

在国内乃至全球,这一领域仍较为空白。我们希望借助社区的力量,吸引更多开发者参与进来,无论是在语法设计、性能优化,还是支持更多类型的硬件设备方面,共同推动技术进步。最后,我们取得了不错的性能表现,相关示例也全部开源,欢迎大家积极提 issue 交流,也期待更多小伙伴加入,共创未来!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值