原论文:GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
在上一篇博客GShard中的混合专家(MoE)中,着重于其中的MoE的相关介绍,GShard还有许多其他亮点,像自动并行训练、高效资源利用、弹性扩展等就搬到这里来吧(主要在附录部分)(请在开始本篇之前看看上篇)(写的有点零碎,俺菜鸡能力差点,还有一点小原因可能是实验部分多少有点)(越写越没有信心让大家看了😳)。
别看他吹的神乎其神,说不定实际上就那样呢~(网络酸民口气.jpg)
深度学习模型的神通众所周知,硬件的计算能力是其必备需求,底层硬件在不断进步的同时,神经网络所依赖的软件系统也随之发展。加速器(如GPU、TPU)能高度并行,编程也就需要了解硬件的底层工作方式,用适合的代码来最大化硬件性能,很复杂,于是催生出许多高层次框架(如PyTorch、Tensorflow),它们简化了神经网络的构建过程,又抽象化了许多硬件相关的细节,依赖于一些低层的库(如Nvidia GPU的CUDA、Google TPU的XLA)来高效驱动硬件。
并行框架大杂烩
在分布式异构的环境中进行编程也是工程师一大难点,弄清分布式计算到底是怎么做的让人千头万绪苦不堪言😰,深度学习框架正在尝试把用户从这样的负担里面解救出来。例如TensorFlow就已经支持数据并行,能按节点设备划分计算图,实现基本的模型并行;专门用于分布式训练和大规模模型并行化的扩展库Mesh TensorFlow,构建在TensorFlow之上,采用SPMD(single program, multiple data)编程模式,所有计算节点运行相同的程序,处理不同的数据,和数据并行简单地分配数据不同,它侧重于如何在多个设备上执行模型的不同部分;GShard只需用户提供一点注释,编译器会根据注释自动进行划分,和深度学习框架FlexFlow通过自动化搜索来发现操作在计算图中的最佳划分策略不同,这里关注的是如何将带有注释的计算图转换成可在多个设备上并行执行的任务;Weight-update sharding是另一种基于XLA的自动并行化改造,专注于TPU集群的性能优化,在分布式训练中将权重更新分配到多个计算节点或设备上;Zero(可能是我最早读的几篇分布式相关的论文了)做了一系列优化来减轻并行训练设备的内存冗余,通过分别划分权重、激活值、优化器状态,将模型规模扩展到了1700亿参数,相较之下GShard就更具一般性,不用区分这些张量,上面那些特定的划分技术都可以通过对相应张量做简单注释实现,参数量可超一万亿,设计选择更多样。
GShard中的XLA SPMD划分器
这个能够基于分片注释,自动化分计算图的编译器的架构到底长什么样?分片注释(sharding annotations)指示了每个张量该如何跨设备分布;SPMD分区器(partitioner)是一个编译器组件,会将计算图放到所有设备上并行执行,已在XLA的编译器上进行了实现;XLA,Google的加速线性代数(accelerated linear algebra)编译器,许多前端框架,如Tensorflow、PyTorch、JAX、Julia都具备将计算图表示转化为XLA HLO图的降级逻辑(更接近硬件),因为降级逻辑解决了大部分复杂操作,故XLA只需支持较少的基础操作,操作符集合比这些前端框架小得多,可以专注于优化这些操作并提高执行效率,编译过程更加简洁和高效。作者也强调,虽然他们的工作是在XLA上实现的,但所采用的技术方法和思想是通用的,也可以应用到其他机器学习框架的IR(中间表示)中。
XLA将计算建模成一个数据流图(dataflow graph),节点代表图中的操作符,边代表张量在不同操作符之间流动的过程。分区器根据输入输出设备分配信息,调整每个操作符的执行粒度,以适应并行环境。当计算图被分区时,跨设备传输数据传输十分重要,为了优化大规模计算的性能,必须定义并优化核心的通信原语,这些原语负责设备间的数据交换,并根据不同硬件平台进行调整和优化。

上图将两种划分方法——SPMD和MPMD(Multiple Program Multiple Data)进行了比对,将向量点积()沿着维度
划分到4台设备上,每台设备算出本地结果后用AllReduce进行合并。MPMD会为每台设备生成独立的操作符,可扩展性受限,而SPMD中每台设备执行的是相同的程序,通过对数据的动态划分来优化计算。
如何跨分区通信?
Einsum是MoE模型中的关键操作,在XLA HLO中通常会被表示为Dot操作(矩阵乘法),操作数(LHS或RHS)可以是多维张量,通常包括三种类型的维度,对应不同计算需求:
1. 批处理维度(batch dimension),可以并行处理的维度,是可以并行执行的独立计算单元,必须在操作数和输出中都有,且每个输出仅依赖于对应批次的操作数。
2. 压缩维度(contracting dimension),LHS和RHS中的公共维度,在输出中会被总和消除掉。
3. 非压缩维度(Non-contracting dimension),每个操作数都有自己的非压缩维度集,这些维度会被继承到输出中。
分片传播时会优先选批处理维度,这样可以避免跨分片通信,但在以下情况中是避免不了的(先把算法2搬出来,之前不是有些Einsum操作很迷惑吗,这里都给解释了!)(再说一遍各符号含义)(:局部组数;
:每组token数;
:每个token的特征维度数;
:专家数;
:专家容量;
:前馈隐藏层维度数。组数和专家数得是设备数
的倍数,
)

1. 更换分片维度(resharding),例如算法2第三行的einsum操作,输入张量都是按组划分,如果输出还是按组划分的话,完全不用跨设备通信,毕竟组间结果互不影响,但输出变成按专家划分,就会它会先本地执行einsum,再利用AllToAll将全局重划到目标维度,如下图流程

2. 累积分片结果(accumulating partial results),一个简单的Matmul操作,将在压缩维度划分后的局部结果AllReduce成最终结果,如下图

配个我理解的例子( 就像图1那样啦)(Matmul矩阵乘法)(字略丑😳)

3. 在循环中切片(slicing in a loop),如果两个操作数被划分在了不同的非压缩维度上,就像下图中的LHS和RHS分别被划分成了A组和C组,希望将结果划成C组,但LHS中一组和RHS中一组的dot-product的结果只是结果一组中的一个数字 ,需要collective-permute在多个计算节点之间循环迭代出结果中的一组。全过程系统都不会用到全尺寸的张量,而是循环处理切片,内存效率高

编译器优化
SPMD划分器创建了各种操作符,XLA编译器利用TPU融合能力(fusion capabilities,指将多个计算操作合并成一个,减少计算过程中不必要的中间数据存储和传输) 和代码移动优化(重排代码中的某些操作来减少不必要的计算和内存访问),可以隐藏这些数据格式化操作的开销。
除算法2带标注版本中提到的两种分片常用API replicate()和split()外,还有很多高级分片策略用来减少数据传输,例如shard(tensor, device_assignment),输入的是待分区的张量和设备分配信息,device_assignment的维度一定要和tensor的相同,指示tensor的每个维度如何划分,假设有一个3D张量,形状为[256, 1024, 8192],对应device_assignment的形状为[2,1,4],则这个张量的第一、二、三个维度分别会被划分成2、1、4个部分,最终每个分区的大小为[128,1024,2048],一个简单的代码示例:
tensor = tf.random.normal([256, 1024, 8192]) # 创建一个随机张量
device_assignment = [[0, 1], [0], [0, 1, 2, 3]] # 设备分配:2x1x4
sharded_tensor = shard(tensor, device_assignment) # 执行分区
还得考虑设备拓扑、张量分片之间通信。(这里感觉蛮工程了,没有之前新奇算法那么惊鸿一瞥的感觉了,读论文好累啊😭)
计算和存储效率
在GShard模型中,主要有三种 内存开销 :复制的权重(如transformer的前向层)、分布式权重(如MoE前向层)、激活值(每层用于前向和反向传播的输出)

上图显示了不同模型中每台设备的内存使用分布,固定层数时,即使增加专家数,激活值和权重的内存消耗依旧保持,它们只会随着层数线性变化。当超出可用内存时,会采取自动重算反向传播时的激活值以减少内存消耗,这也是模型MoE(2048E,60L)中激活值消耗的内存少于MoE(2048E,36L)的原因。

上图对 执行时间 做了一个分解,并和roofline模型做了比较。只展示了前向过程,后向过程分解类似。roofline模型是一种用于分析和评估并行计算性能的可视化模型,通过图形方式展现在给定硬件资源条件下,程序或算法的理论最佳性能(即“屋顶”)(计算、内存、通信带宽受限情况下,FLOPS、内存带宽、互联带宽等性能的最佳表现,一种非常乐观的估计)。上图中最小规模模型,128个专家时可达超70%的roofline性能,扩展到16倍大,2048个专家时,执行时间变成原先1.7倍,仍可达48%的roofline性能。
“Gate Einsums” 表示的是算法2中的前两个和最后一个einsum操作,第一个是计算softmax操作的输入的映射,消耗,这是这一层开销很小的一部分了,另外两个是划分tokens和结合各专家结果,用独热矩阵有效实现Gather,代价较高,也不过是常数
,与专家数无关。专家数从128变到2048时,这些einsum操作的执行时间仅翻倍。
其余的单设备门控计算涉及许多通用计算,如ArgMax和Cumsum,受内存限制或本质上就是顺序的,故设计时不能很好地利用TPU,多数时间花在了顺序的Cumsum操作上,将代表每个token对应的专家的独热矩阵,转化成每个专家对应的token的独热矩阵。“Gate cumsum” 这块执行时间,和softmax前的einsum操作近似,仅差一个很小的常数因子,128个专家时可忽略,2048个专家时不到总花销的10%。解释一下,Cumsum(累加和)操作就是返回一个与输入数组相同形状的新数组,其中每个位置的元素是输入数组从起始位置到该位置的所有元素的累积和,如下
import numpy as np
x = np.array([1, 2, 3, 4]) # 多维张量可以指定维度
result = np.cumsum(x)
print(result) # 输出[ 1 3 6 10]
“MoE dispatch and combine” 代表AllToAll的跨设备通信开销,是门控中最重要的部分,花销,专家数从128涨到2048时,它们在总执行时间的占比从16%涨到36%。
通信开销
通信微基测试和各操作可扩展性(communication microbenchmarks and per-operator scability)
在MoE架构中两个关键的集体通信原语就是AllReduce(整合各分片的结果)和AllToAll(更换分片维度)。下图展现了分片数从16到2048时,它们的性能变化趋势。TPU上AllReduce的执行时间和设备数无关,图中曲线的跌宕是源于具体设备拓扑结构的差异;而AllToAll的花销则会随着分片变多而越发昂贵,呈亚线性增长。

为什么在2D TPUs集群中,AllToAll的大致开销是?每个分片发送的数据量固定,总数据量就是
,同时每个数据分片平均要跳
跳,网络中有
条设备到设备的链路,所以在带宽受限的情况下,执行时间就是
。即使在延迟受限的场景下,执行时间仍是
,分块数从16变到2048,增长到128倍时,AllToAll的执行时间也仅增长到原来的9倍,这样看用resharding实现跨分区调度还蛮高效的。
AllGather(replicating one operand)的输出是输入的倍,输入固定时,通信开销就是
。CollectivePermute(replicating one operand)是1对1的通信模式,在源-目标节点距离较近的情况下,输入大小固定时,开销
。

上图是几个分区操作的扩展性,最后一列中通信原语的缩写有AR(AllReduce)、AG(AllGather)、CP(CollectivePermute)、AA(All To All)。多数运算符在计算和通信方面都具有次线性的可拓展性。最后两个Matmul操作在每个分区上的计算和通信规模都在,不是说没有效划分,而是这两完整算下来就很费劲。“*”注释:将Einsum操作调度到计算中,根据不同的维度和分配方式来执行,
;“**”注释:I/O是输入/输出特征维度,B是批量维度,X/Y是输入的空间维度;x/y表示卷积核大小。
对稠密模型下手
GShard并不限于稀疏矩阵,论文中也将其应用到了一个很大的高达数万亿参数的稠密transformer上,示例代码已 开源 ,其中还有如何在公有云上训练这个模型的详细步骤。GShard张量同时在多个维度上进行划分,例如可以同时在批量和模型的维度上划分激活值,这允许长序列和全局批量数小于设备数的情况的出现。

上表中参数量从64B扩展到1T时,性能呈线性增长;当进一步扩展到4T时,由于小批量的原因,计算/通信比率降低,通信瓶颈开始占据主导,也许可以通过将模型扩展到更多的TPU核上或执行梯度累计来实现更大的批量,为确保比较的公平性,表中的4T模型并未做此优化。“MXU util”指的是TPU中矩阵单元利用率。
如何对稠密Transformer进行分片的?作者使用形为[X,Y]的二维网络布置TPU设备,X用于划分批量维度,Y用来划分模型维度,使激活值在所有设备上进行全划分,具体通过ReduceScatter(相当于AllReduce+DynamicSlice,但实现上更有效)运算符进行分片,在下一层通过AllGather合并分片。为进一步减少每台设备上的权重存储,将权重沿X维度划分,在需要的时候,会通过subgrouped AllGather操作跨设备合并分片权重,这种操作模式和权重更新分片(weight-update sharding)类似。
平行束搜索(flat beam search)解码
作者采用带长度归一化的束搜索进行解码,自回归解码,每次只产生一个token。每个解码器的MoE层都有分片/合并操作,需要跨设备通信。推理过程使用和训练时一样的集群。
但这不是一般的束搜索,它做了平行化,即将所有候选解逐个token进行交织,变成一个长序列以便在下一步中一起处理;同时也对解码器的自注意力掩码做了相应的调整,确保每个束假设只关心自身的部分;对每层的键值张量施加相同的转换,确保每次束假设扩展时能适应新的假设,不需要对键值张量重新排序;需要重新排序是指示当前活跃假设的0/1掩码。但耗时变成k倍。
好坏取决于实现细节。内存带宽对于Transformer模型增量解码的过程而言至关重要,通过平行化,将两个计算/存储比低的操作(注意力点积、键值重排)换成了略高的(在更长序列上、带更多键的注意力点积)。
(具体的实验参数有M=1024;H=8192;多头注意力中头数=16;注意力的键值维度=128)(还有好多其他的具体实验设置)(就这样吧~)