作者:你的真实姓名
知乎:https://www.zhihu.com/question/650979052/answer/3501160453
最近看到知乎一个回答,把千卡训练的难度吹上天了。但其实真正用过千卡就会发现也就那么几个点。于是想写一篇文章简单讲讲。
本文将包括3个部分:首先我们将讨论千卡训练的难题,以及应该在什么时候使用千卡训练;接着,我们将讨论如何在一千张卡上开始训练,如何让他达到近乎线性的性能提升;最后我们将展开讨论一些千卡训练当中仍然悬而未决(至少对于开源社区来说)的问题。
为什么千卡训练是困难的?
千卡训练和八卡训练的区别是—显卡多了一百多倍。
这意味着什么呢?
-
通信时间增加
-
故障概率增加
这俩问题都很好理解。
时间上,PyTorch内部支持NCCL/Gloo/MPI三个通信后端(请务必使用NCCL。其中AllReduce操作会会根据具体硬件配置走Ring AllReduce和Tree AllReduce。Ring的时间复杂度是,Tree的时间复杂度是 。就算是理论上128节点也比单节点慢至少七倍,实践当中跨节点通讯要远比单节点慢得多。
故障上,一个节点出问题的概率是p,128个节点就是1-(1-p)^128。也就是说如果一个操作在一个训练当中的出错概率是1%,那么在128节点当中的出错概率就是72.37%。
此外,随着规模的增大,许多问题都会变得难以忍受。比如数据增强要花0.1s,一亿条数据就是278个小时(当然这只是胡拆的一个数字,实际有各种机制所以不会有这么大影响。
因此,钱多烧手并不