涨知识了,怎么估算大模型的训练时长?

部署运行你感兴趣的模型镜像

 Datawhale分享 

作者:冰冰,知乎

作者:冰冰,一个着迷于数学的软件工程研究僧

来源:https://zhuanlan.zhihu.com/p/1941384271051854082

昨天睡前和朋友聊天,他讲到自己接到老板的一个需求,让他根据不同的数据集和 tokenizer 作为输入,去估计训练时长要多久。

看到这个问题,我激情输出了十来分钟,经过了大概两轮的输出,最终确定了该怎么去做这件事情。

下面就简单说说想法吧,顺便把自己的想法给完善一下,这边主要说的是 SFT,RL 的情形太复杂了,主包算不明白。

训练时长这件事情,可以抽象成一个非常简单的公式:

如果能够把分子分母的两个量全都计算出来,就可以把训练时长给估计出来。那问题就变成了怎么把这两个量给求出来。

此处就不去考虑各种诸如 TP,PP,DP,EP,CP 这些模型并行的策略了,我只是大概了解基于 All-Reduce、All-Scatter 和 All-Gather 的通信/并行概念,但并不知道这些并行方式底层的通信方式和所需的通信量。

而且,如果要考虑这些因素的话,可能还需要考虑单机内和多机之间的通信模式,完全在我的能力范围之外了。

把上面的那个数算出来,就已经能估计出训练时长的下界,将将够用。最后无非就是基于 PCIE 来做卡间通信会更慢点,用 NVLINK 连的话就会更快点。

一、显卡算力计算

首先看分母上的那个量,这个其实很简单,每秒的总算力就是:

这个东西还是很好算的,以单机 8 卡为例,在 FP16 的精度下,各种卡的 TFLOPS 大概是长这样的:

上面大概覆盖了主流的一些卡的 fp16 算力,查表就可以用。

另外,因为现在训练的时候通常使用混合精度训练(Mixed Precision),前向传播/反向传播的时候用 fp16,在用 Optimizer 做 gradient accumulation 的时候,通常要用 FP32 做累加。

为了考虑进这个因素,我问了下 GPT5,他给我的回复是:专业的计算卡做过优化,算 fp32 的时候,速度不会掉点;

但是像 5090 这样的消费级显卡,做 fp32 的 accumulation 时会掉一半的性能。【若 GPT5 回答的有闪失,谢谢大佬指正】这个在计算时要注意一下。

二、训练所需的总计算量

接着就要来计算训练所需的总计算量了,这个东西我琢磨了一下,大概会和以下几个东西有关系:Epoch 数,模型大小,数据集的 token 数还有 tokenizer 效率。

Epoch 数应该是当中最简单的因素,一个 Epoch 要多少的计算量,有了多个 Epoch 之后,直接乘以 Epoch 数就行了:

这个时候再对每 Epoch 的计算量做计算,在模型训练阶段所需要考虑的两个因素:

  • 数据集有多少 token;

  • 模型前向传播+反向传播所需计算量。

数据集有多少 token 这很好统计,主要还是要考虑第二个问题,对于这个问题,在 OpenAI 2020 年的文章 Scaling Law 里早有记载:

其中,6 这个系数可以拆成两个部分,前向传播占用了其中的 2 份,反向传播因为要两倍的计算量,所以占用了其中的 4 份。

用这个公式的话,就不用再考虑很多零碎的细节,像 hidden_size、head_nums 和 dim_head 统统不用考虑,拿来就用,非常开心。

GPT5 告诉我,GPT3 和 Chinchilla 等工作都是用这个来估算的。

这个东西算完以后,就要考虑一个问题,怎么根据数据集来确认总 token 数呢?

这个数据完全是跟着 tokenzier 来的,但又不可能为了算一个数据集就去对着数据集 tokenize 一遍,这个太费劲了。

所以,根据上面的套路,这个地方适合用某种方法预处理出来一个量,我打算管这个玩意叫 tokenizer 效率,衡量了 tokenizer 的压缩率;可以是 token per 每个字,也可以是 token per byte。

总之把这个数据给估算出来,就可以拿这个值直接去算一个数据集可能会分成多少个 token 了。

因为每家的用自己的语料做出来的 tokenizer 不太一样,所以需要预处理一下每个 tokenizer 的效率。

这里呢,我大概想到了两种做法:

  • 第一种,这种就比较粗暴,直接用一批数据去做 tokenize,把 tokenize 完得到的 token 数去除以字符串长度/byte 数,就可以得到 tokenizer 的效率;

  • 第二种,如果希望估计得更加精细的话,可以用线性回归去拟合一个值出来,其中 是句子的 token 数量,是句子长度/byte 数。

至于为啥我在这边要老提 byte 数,因为不同的字符对应的 byte 数是不一样的,英文字符的话使用 ASCII 码,一般就 1 个字节,但是像中文单字的话,一般会使用 3-4 个字节来表示。

基于 BPE 的 tokenizer,是预处理了在一批语料中最常见的 Byte Pair 作为词表去做 tokenize 的。

这也就导致了像中文这样的单字,有时可能会被拆成好几个 token,看过词表的 uu 们应该都知道,在那个表里面是找不到中文单字的,而且经常能看到一些长相非常奇怪的字符串。

我觉得,从 byte 层面进行估算,会更精确些(大概)。至于用来算 tokenizer 效率的这样一个数据集,多样性要强一点,尽量把整个词表都覆盖一下。

不过,你也可以完全当我前面是在放 X,直接用经验法则去估计:1个英文字母差不多等于 3/4 个 token,一个中文字约等于 1 个 token。

到这边,整个估算流程就差不多齐活了,整个过程可以浓缩成下面这样一个公式:

反正挨个把数据套进去,训练时长的下界就可以被估算出来了。很方便啊,很方便。估算过程就到这边为止了。

三、幺蛾子想法x1

然后,在思考这个问题问题的时候,我回想起前天在看 Kimi K2 技术报告时突发奇想到的一个幺蛾子问题:

是否可以通过这份开源的技术报告里公布的训练细节去反推出厂商能够多少张卡?

这个问题我觉得还是非常有意思的,当时想搁置一下慢慢研究的;我当时的思路就是,可能需要根据开源报告里提出的各种并行策略,参数设置去推断。

但是,当我昨天把上面那个问题琢磨出来以后,就发现一个更加简单的估算方法:报告里通常会把 Epoch 数,模型参数,还有训练的 token 总量放出来。

如果他还大发慈悲地把训练时长也公布出来,那直接除一除不就可以得到集群的算力了嘛。

至于在 PP 之间可能会产生气泡,我看他们都在尽量压缩气泡的大小,那直接当作没气泡的状况去估计得了,糙一点也不是什么坏事。

至于这个东西有啥用,我刚想到一个比较有意思的用途。可以去看一个大厂里不同的组做出来的模型都用到了多少张卡,也可以间接估计这个组在公司里的地位怎么样,对于找工作的各位可能有一定的参考价值罢(x

这边就先写到这里,之后可能会密集更新很多自己的幺蛾子想法和学习心得,我发现和评论区的网友 battle 能够帮助我去思考一些平时没有思考到的问题,在 battle 的过程中还是能学到很多的东西的,输出为王!

感谢你看到这儿 ^^

图片

一起“赞”三连

您可能感兴趣的与本文相关的镜像

Wan2.2-I2V-A14B

Wan2.2-I2V-A14B

图生视频
Wan2.2

Wan2.2是由通义万相开源高效文本到视频生成模型,是有​50亿参数的轻量级视频生成模型,专为快速内容创作优化。支持480P视频生成,具备优秀的时序连贯性和运动推理能力

内容概要:本文详细介绍了一种基于Simulink的表贴式永磁同步电机(SPMSM)有限控制集模型预测电流控制(FCS-MPCC)仿真系统。通过构建PMSM数学模型、坐标变换、MPC控制器、SVPWM调制等模块,实现了对电机定子电流的高精度跟踪控制,具备快速动态响应和低稳态误差的特点。文中提供了完整的仿真建模步骤、关键参数设置、核心MATLAB函数代码及仿真结果分析,涵盖转速、电流、转矩和三相电流波形,验证了MPC控制策略在动态性能、稳态精度和抗负载扰动方面的优越性,并提出了参数自整定、加权代价函数、模型预测转矩控制和弱磁扩速等优化方向。; 适合人群:自动化、电气工程及其相关专业本科生、研究生,以及从事电机控制算法研究与仿真的工程技术人员;具备一定的电机原理、自动控制理论和Simulink仿真基础者更佳; 使用场景及目标:①用于永磁同步电机模型预测控制的教学演示、课程设计或毕业设计项目;②作为电机先进控制算法(如MPC、MPTC)的仿真验证平台;③支撑科研中对控制性能优化(如动态响应、抗干扰能力)的研究需求; 阅读建议:建议读者结合Simulink环境动手搭建模型,深入理解各模块间的信号流向与控制逻辑,重点掌握预测模型构建、代价函数设计与开关状态选择机制,并可通过修改电机参数或控制策略进行拓展实验,以增强实践与创新能力。
根据原作 https://pan.quark.cn/s/23d6270309e5 的源码改编 湖北省黄石市2021年中考数学试卷所包含的知识点广泛涉及了中学数学的基础领域,涵盖了实数、科学记数法、分式方程、几何体的三视图、立体几何、概率统计以及代数方程等多个方面。 接下来将对每道试题所关联的知识点进行深入剖析:1. 实数与倒数的定义:该题目旨在检验学生对倒数概念的掌握程度,即一个数a的倒数表达为1/a,因此-7的倒数可表示为-1/7。 2. 科学记数法的运用:科学记数法是一种表示极大或极小数字的方法,其形式为a×10^n,其中1≤|a|<10,n为整数。 此题要求学生运用科学记数法表示一个天文单位的距离,将1.4960亿千米转换为1.4960×10^8千米。 3. 分式方程的求解方法:考察学生解决包含分母的方程的能力,题目要求找出满足方程3/(2x-1)=1的x值,需通过消除分母的方式转化为整式方程进行解答。 4. 三视图的辨认:该题目测试学生对于几何体三视图(主视图、左视图、俯视图)的认识,需要识别出具有两个相同视图而另一个不同的几何体。 5. 立体几何与表面积的计算:题目要求学生计算由直角三角形旋转形成的圆锥的表面积,要求学生对圆锥的底面积和侧面积公式有所了解并加以运用。 6. 统计学的基础概念:题目涉及众数、平均数、极差和中位数的定义,要求学生根据提供的数据信息选择恰当的统计量。 7. 方程的整数解求解:考察学生在实际问题中进行数学建模的能力,通过建立方程来计算在特定条件下帐篷的搭建方案数量。 8. 三角学的实际应用:题目通过在直角三角形中运用三角函数来求解特定线段的长度。 利用正弦定理求解AD的长度是解答该问题的关键。 9. 几何变换的应用:题目要求学生运用三角板的旋转来求解特定点的...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值