作者:你的真实姓名
知乎:https://www.zhihu.com/question/650979052/answer/3501160453
最近看到知乎一个回答,把千卡训练的难度吹上天了。但其实真正用过千卡就会发现也就那么几个点。于是想写一篇文章简单讲讲。
本文将包括3个部分:首先我们将讨论千卡训练的难题,以及应该在什么时候使用千卡训练;接着,我们将讨论如何在一千张卡上开始训练,如何让他达到近乎线性的性能提升;最后我们将展开讨论一些千卡训练当中仍然悬而未决(至少对于开源社区来说)的问题。
为什么千卡训练是困难的?
千卡训练和八卡训练的区别是—显卡多了一百多倍。
这意味着什么呢?
-
通信时间增加
-
故障概率增加
这俩问题都很好理解。
时间上,PyTorch内部支持NCCL/Gloo/MPI三个通信后端(请务必使用NCCL。其中AllReduce操作会会根据具体硬件配置走Ring AllReduce和Tree AllReduce。Ring的时间复杂度是,Tree的时间复杂度是 。就算是理论上128节点也比单节点慢至少七倍,实践当中跨节点通讯要远比单节点慢得多。
故障上,一个节点出问题的概率是p,128个节点就是1-(1-p)^128。也就是说如果一个操作在一个训练当中的出错概率是1%,那么在128节点当中的出错概率就是72.37%。
此外,随着规模的增大,许多问题都会变得难以忍受。比如数据增强要花0.1s,一亿条数据就是278个小时(当然这只是胡拆的一个数字,实际有各种机制所以不会有这么大影响。
因此,钱多烧手并不是使用千卡训练的理由。闲得蛋疼可能是,但你得多蛋疼才能想出这么折磨自己的idea?
千卡训练解决的问题是大模型&大数据问题。如果你的训练时间没有超过8192GPU日,那么你绝对不需要一千张显卡。
看到这里,绝大多数人已经可以关掉这篇文章了。除非你的模型和数据都以B(十亿)来作为计量单位。当然如果你正在厕所里手机没电想看点儿东西解闷儿的话(虽然我很怀疑是否会有人把他打出来……那么可以继续往下看
如何使用一千张卡训练?
如何提高计算效率?
这件事情其实是一个case by