这一章节虽然是BMTrain,不是目前常用的Megatron+DeepSpeed,但是对于了解原理,也是很有帮助。
BMTrain
数据并行
一般数据并行
上图,把数据切为3份,每张显卡处理一部分数据,每张显卡利用得到的数据进行前向传播和反向传播,得到各自的梯度,为了让模型学到这份数据的所有知识,就需要把这些梯度的信息进行一个聚合,也就是取平均的操作。那么我们得到聚合好的参数去更新模型,就能学到切分3部分数据合起来的完整的数据的知识。
具体来说:在数据并行的过程中,我们有一个参数服务器,里面保存了模型的参数,还有完整的一批数据,前向传播的过程中,参数服务器上的参数,会被复制到所有的显卡上,每张显卡上都得到了很参数服务器上一样的模型参数,然后把数据切分为3份,每张显卡上各拿到一部分数据,然后每张显卡用完整的模型参数和一部分的数据去进行前向传播和反向传播,我们就能够得到每张显卡上各自的梯度,最后将这个梯度进行聚合,将聚合后的梯度传回我们的参数服务器。那么参数服务器上面有了原始的模型参数和这个聚合好的模型的完整的梯度,我们就可以用优化器去对模型的参数进行更新,那么更新后的参数又会进入下一轮的模型的训练迭代
集合通信
- broadcast
把数据从一张显卡传到其它所有的显卡上 - reduce
规约可以是求和、平均、MAX等,把各张显卡上的数据进行一个规约,然后把规约得到的结果,放到我们其中一张指定的显卡里面。 - all reduce
和reduce几乎相同,不同点就是把结果发送到所有显卡上面(reduce是发送到指定的一张显卡上) - reduce scatter
跟all reduce相比,相同之处是它们都把规约得到的结果发给所有显卡,不同之处在于,reduce scatter,最后每张显卡上只得到一部分的规约结果。
如上图,0号显卡,会得到in0的前1/4的参数,加上in1的前1/4的参数,加上in2的前1/4的参数,加上in3的前1/4的参数,得到out0。其他同理。 - all gather
可以跟all reduce类比下,把各张显卡上的数据进行收集,然后进行一个拼接,然后广播到所有显卡上,所有显卡得到了一个搜集后的结果。
分布式数据并行
对数据并行进行了优化,没有参数服务器。
每张显卡各自完成参数更新。然后保证参数更新后的结果一致。
数据并行显存优化
计算过程的中间结果是跟batch乘以句子长度和模型维度相关的一个显存占用。使用数据并行的时候,把一批数据分成了N份,让每张显卡只处理其中的一部分数据,等效于每张显卡上所处理的batch的大小将变成了原来显卡数分之一,那么通过把这个输入的维度进行了一个降低,那我们模型整体的中间结果量(每张卡),也进行了一个降低。
但是这个方法有一个缺点,就是为了支持模型的训练,它至少需要训练1条数据,最极端情况下,每张显卡只得到一个数据的时候,由于我们的参数哈需要完整的保存在显卡上,梯度也需要完整保存在显卡上,优化器也需要完整保存在显卡上,那么即使中间结果一点都不在显卡上,我们模型仍然有可能无法在一张显卡上进行计算。
模型并行
一张显卡无法放下模型的所有参数、所有梯度、所有优化器,就想办法把一个模型分成很多个小部分
思路:针对线性层的矩阵乘法的例子,上图左上角,3行2列的矩阵乘以2行1列的矩阵,本质上可以把它的结果分成三个 部分。
可以把3X2的矩阵看成线性层的参数W,把2X1向量看成线性层的输入。
通过上面的方法,我们的线性层的参数,就可以划分到多张显卡上。而且我们需要保证多张显卡上模型的输入是一样的。那么我们就不能使用数据并行的那一套方式对数据进行划分了,我们需要保证每一张显卡上得到的输入是一样的,所以他们是同一批数据。