分析transformer模型的参数量、计算量、中间激活、KV cache、bf16、fp16、混合精度训练

难得一遇的好文,转载自https://zhuanlan.zhihu.com/p/624740065,先做个简单总结:

训练时的参数量由模型参数后向的梯度优化器参数前向的中间激活组成:

假设采用Zero3,模型参数量为B,混合精度训练,卡数为N:

  • 优化器参数、后向的梯度、模型参数是DeepSpeed中分的参数,每块卡分20B/N,关于20B怎么来的可以参考 Deepspeed、ZeRO、FSDP、ZeRO-Offload、all reduce、reduce-scatter简单记住20B/N或者16B/N即可,假设8卡A100训练一个1B的模型,这部分每张卡占用4G*20/8=10G左右,你的A100里的10个G就没有了。
  • 前向的中间激活才是真的大头。中间激活占用的显存大小可以近似为:
    ( 34 b s h + 5 b s 2 a ) ⋅ l (34bsh + 5bs^2a) \cdot l (34bsh+5bs2a)l,其中b是batch_size,s是序列长度,l是模型层数,a是有几个头,如果transformer模型的参数量近似为 12 l h 2 12lh^2 12lh2,可以直观感受一下比例。算的时候需要计算器没法口算,但是要知道中间激活是和batch_size 100%成正比和模型头数是一部分成正比,和序列长度一部分正比,另外一个部分平方成正比。以GPT3 2048序列长度感受一下,bs=1是中间激活是模型大小的0.79倍。而且中间激活是每个卡独占算的时候不能像上面Zero3一样给除N。这就是为啥8卡Finetune一个7B的模型,序列长度是4096,如果bs=1时,可能每张卡大概中间激活大小就超过模型的2倍,也就是至少占了50多G的显存,bs进一步大的时候,就直接OOM了。

模型参数

假设Transformer的hidden_size是h,那么总的一层transformer层的参数量为 12 h 2 + 13 h 12h^2+13h 12h2+13h ,l层的transformer计算量就是 l ∗ ( 12 h 2 + 13 h ) l*(12h^2+13h) l(12h2+13h):

  • 多头注意力四个参数矩阵Q、K、V、O和他们的bias对应 4 h 2 + 4 h 4h^2+4h 4h2+4h
  • position wise feedforward对应 4 h 2 + 4 h + 4 h 2 + h 4h^2+4h+4h^2+h 4h2+4h+4h2+h
  • 多头注意力和position wise feedforward各有一个LayerNorm对应2个可训练模型缩放参数beta和平移参数gamma参数是4h

前向计算过程中产生的中间激活

前向计算过程中产生的中间激活,中间激活值与输入数据的大小(批次大小b和序列长度 l)是成正相关的,随着批次大小b和序列长度l的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存。

前向的中间激活才是真的大头。中间激活占用的显存大小可以近似为:
( 34 b s h + 5 b s 2 a ) ⋅ l (34bsh + 5bs^2a) \cdot l (34bsh+5bs2a)l,其中b是batch_size,s是序列长度,l是模型层数,a是有几个头,如果transformer模型的参数量近似为 12 l h 2 12lh^2 12lh2,可以直观感受一下比例。算的时候需要计算器没法口算,但是要知道中间激活是和batch_size 100%成正比,和模型层数100%成正比和模型头数是一部分成正比,和序列长度一部分正比,另外一个部分平方成正比

可以看到随着批次大小 b b b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用 激活重计算技术来减少中间激活,理论上可以将中间激活显存从 O ( n ) O(n) O(n) 减少到 O ( n ) O(\sqrt{n}) O(n ),代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。

后向传递计算得到的梯度、优化器状态

在一次训练迭代中,每个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam优化器梯度的一阶动量和二阶动量)。设模型参数量为 Φ \Phi Φ,那么梯度的元素数量为 Φ \Phi Φ,AdamW优化器的元素数量为 2 Φ 2\Phi 。float16数据类型的元素占2个bytes,float32数据类型的元素占4个bytes。在混合精度训练中,会使用float16的模型参数进行前向传递和后向传递,计算得到float16的梯度;在优化器更新模型参数时,会使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。因此,对于每个可训练模型参数,占用了 ( 2 + 4 ) + ( 2 + 4 ) + ( 4 + 4 ) = 20 b y t e s (2+4)+(2+4)+(4+4)=20bytes (2+4)+(2+4)+(4+4)=20bytes。使用AdamW优化器和混合精度训练来训练参数量为 Φ \Phi Φ的大模型,模型参数、梯度和优化器状态占用的显存大小为 20 Φ 20\Phi 20Φ

推理时的参数量少了梯度、优化器状态、中间激活,但多了kvcache

假设输入序列的长度为s,输出序列的长度为n,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 b ∗ ( s + n ) h ∗ l ∗ 2 ∗ 2 = 4 l h ( s + n ) b*(s+n)h*l*2*2 =4lh(s +n) b(s+n)hl22=4lh(s+n)。这里第一个2表示K/V cache,第个2表示float16占2个bytes。
以GPT3为例,对比KV cache与模型参数占用显存的大小。GPT3模型占用显存大小为350GB。假设批次大小b=64,输入席列长度 =512,输出序列长度n =32,则KV cache占用显存大约为 46 l h ( s + n ) 46lh(s + n) 46lh(s+n)= 164,282,499,072bytes约等于164GB,大约是模型参数显存的0.5倍

训练时的计算量

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小

对于 A ∈ R m × n A \in R^{m \times n} ARm×n, B ∈ R n × p B \in R^{n \times p} BRn×p,矩阵A✖️B浮点运算需要mnp次乘法和mnp次加法,因此FLOPs为2mnp。这个刚好是计算复杂度*2的关系,因此我们下面只估计乘法的计算次数,最后算FlOPs乘2就好了

模型计算复杂度

假设Transformer的hidden_size是h,序列长度是s,batch大小是b,那么总的一层transformer层的参数量为 12 b s h 2 + 2 b s 2 h 12bsh^2+2bs^2h 12bsh2+2bs2h ,l层的transformer计算量就是 l ∗ ( 12 b s h 2 + 2 b s 2 h ) l*(12bsh^2+2bs^2h) l(12bsh2+2bs2h):

  • 多头注意力四个参数矩阵Q、K、V、O对应 4 b s h 2 4bsh^2 4bsh2(忽略他们的bias)
  • QK、KV的复杂度是 2 b s 2 h 2bs^2h 2bs2h [ b , s , h ] ∗ [ b , h , s ] = [ b , s , s ] [b,s,h]*[b,h,s]=[b,s,s] [b,s,h][b,h,s]=[b,s,s] [ b , s , s ] ∗ [ b , s , h ] = [ b , s , h ] [b,s,s]*[b,s,h]=[b,s,h] [b,s,s][b,s,h]=[b,s,h],这俩都是 b s 2 h bs^2h bs2h
  • position wise feedforward对应 4 b s h 2 + 4 b s h 2 4bsh^2+4bsh^2 4bsh2+4bsh2

计算量与参数量的关联

当隐藏维度h比较大,且远大于序列长度s时,我们可以忽略一次项,计算量可以近似为 ( 12 b s h 2 + 2 b s 2 h ) ∗ 2 12 h 2 + 13 h ≈ 2 \frac{(12bsh^2+2bs^2h)*2}{12h^2+13h}\approx2 12h2+13h(12bsh2+2bs2h)22 。我们可以近似认为:在一次前向传递中,对于每个token,每个模型参数,需要进行2次浮点数运算,即一次乘法法运算和一次加法运算。

一次训练迭代包含了前向传递和后向传递,后向传递的计算量是前向传递的2倍(loss.backward()和optimizer.step()两步)。因此,前向传递 + 后向传递的系数 = 1+2 =3 。一次训练迭代中,对于每个token,每个模型参数,需要进行 2*3=6次浮点数运算。

训练时间估计

模型参数量和训练总tokens数决定了训练transformer模型需要的计算量。给定硬件GPU类型的情况下,可以估计所需要的训练时间。给定计算量,训练时间(也就是GPU算完这么多flops的计算时间)不仅跟GPU类型有关,还与GPU利用率有关。计算端到端训练的GPU利用率时,不仅要考虑前向传递和后向传递的计算时间,还要考虑CPU加载数据、优化器更新、多卡通信和记录日志的时间。一般来讲,GPU利用率一般在0.3~0.55之间。

上文讲到一次前向传递中,对于每个token,每个模型参数,进行2次浮点数计算。使用激活重计算技术来减少中间激活显存(下文会详细介绍)需要进行一次额外的前向传递,因此前向传递 + 后向传递 + 激活重计算的系数=1+2+1=4。使用激活重计算的一次训练迭代中,对于每个token,每个模型参数,需要进行2*4=8次浮点数运算。在给定训练tokens数、硬件环境配置的情况下,训练transformer模型的计算时间为:
在这里插入图片描述

fp16和bf16的相关问题

为什么有了fp16又整了一个bf16

  • BF16 的设计思想是在不改变内存占用的情况下,用1/10的精度换取了 10^34倍的值域。
  • BF16 只有 2bytes 的内存,但其数值范围与 4bytes 的 FP32 相同。在深度学习领域,数值范围的作用远高于数值精度;即数据类型的指数位的作用大于尾数位的作用。采用SGD的网络权重的更新方式为 θ t = θ t − 1 − η ∗ g t \theta_t=\theta_{t-1}-\eta*g_t θt=θt1ηgt。由于梯度 g t g_t gt和学习率 η \eta η通常较小,往往会出现很小的数值,因此必须使用能够表达较大范围的数据类型。使用 FP16时往往会出现 underflow(下溢)的情况,即小于 − 6.55 × 10^4 的数值被截断为0 ,导致梯度无法更新。所以 BF16 具有比 FP16 更优异的性能。
  • BF16 与 FP32 的转换很容易。使用混合精度计算时,需要频繁得对 BF16/FP32 和 FP32 进行转换。BF16 基本上可以看作成一个“截断”版的 FP32, 两者之间的转换是非常直接,其实现电路也会非常简单。相比于 FP16,BF16的使用能有效的降低电路的面积。

原文链接:https://blog.youkuaiyun.com/qq_43799400/article/details/134182459

如果还是v卡,继续用fp16做混合精度训练,大致的流程如何?

MIXED PRECISION TRAINING 来自于论文https://arxiv.org/pdf/1710.03740.pdf,
整体流程:FP32权重 -> FP16权重 -> FP16计算前向 -> FP32的loss,扩大 -> 转为FP16 -> FP16反向计算梯度 -> 缩放为FP32的梯度更新权重
在这里插入图片描述
值得注意的是:

  1. Loss scaling:由于梯度值都很小,用FP16会下溢,因此先用FP32存储loss并放大,使得梯度也得到放大,可以用FP16存储,更新时变成FP32再缩放。如果梯度过大导致FP16上溢,Nvidia Apex,或者Pytorch1.6 AMP都会跳过step()然后调整Loss的放大倍数。
  2. 在涉及到累加操作时,比如BatchNorm、Softmax,需要用FP32保存,一般使用GPU中TensorCore的FP16*FP16+FP32=FP32运算。这里是用FP32进行存储的主要原因应该是为了避免精度损失,因为如果直接在FP16上进行累加,可能会出现舍入误差(毕竟FP16在每一个数值区间能够表示的最小间隔比FP32的大)。

原文链接:https://zhuanlan.zhihu.com/p/110278004

inference的时候哪些算子不能转换fp16?

当然还是刚才说到的涉及到累加操作时,比如BatchNorm、Softmax,尤其是BatchNorm,这里引用来自apex的一个issue(https://github.com/NVIDIA/apex/issues/122),当然这这个issue也是说的训练时候怎么转,【Quantization aware training(神经网络量化训练)】https://www.bilibili.com/video/BV13s4y1D73L?vd_source=e260233b721e72ff23328d5f4188b304 里面强调做一些量化的时候要把BN算子给先合并掉


以下是转载截屏:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

<think>我们正在讨论的是参数 `kv_cache_dtype="fp8_e4m3"` 的含义。根据引用[1]、[2]和[3],我们可以得到以下信息: 1. **FP8 量化**:FP8(8位浮点数)是一种用于深度学习的低精度数值格式,旨在减少模型推理时的内存占用和提高计算效率。它有两种主流格式:E4M3(4位指数,3位尾数)和E5M2(5位指数,2位尾数)[^1][^2]。 2. **KV Cache**:在自回归语言模型(如Transformer)的解码过程中,为了高效生成token,通常会缓存键(Key)和值(Value)向量,避免重复计算。这个缓存称为KV Cache。 3. **`kv_cache_dtype` 参数**:指定了KV Cache中存储的数据类型。使用`"fp8_e4m3"`表示将KV Cache中的元素以FP8 E4M3格式存储。 4. **E4M3格式详解**: - **位分配**:1个符号位 + 4位指数 + 3位尾数(共8位)。 - **动态范围**:E4M3格式的表示范围约为$[-3.4\times10^4, 3.4\times10^4]$,但精度相对较低(尾数仅有3位,有效精度约为2位十进制数)[^1][^3]。 - **优势**:相比FP1616位浮点数),内存占用减少一半(从16位→8位),从而允许缓存更多上下文或更大的批次大小。 5. **计算过程**:虽然KV Cache使用FP8存储,但在计算注意力分数时,通常需要将FP8数据转换回更高精度(如FP16BF16)进行计算,以避免累积误差。然而,根据引用[3],TensorRT-LLM在特定条件下(如使用特定编译选项)可以直接用FP8参与注意力计算[^3]。 6. **实际效果**: - **内存节省**:KV Cache内存占用减少50%[^2]。 - **吞吐量提升**:引用[2]指出,使用FP8量化可将吞吐量提升最多1.6倍[^2]。 - **精度影响**:对模型输出质量影响极小(通常在1%以内)。 ### 技术细节补充 - **为什么选择E4M3?** E4M3相比E5M2有更高的精度(尾数多1位),适合对精度要求更高的KV Cache数据[^1]。 - **硬件要求**:仅在支持FP8计算的GPU上有效(如NVIDIA H100、L40S,AMD MI300X)[^2]。 - **与权重量化的区别**:此参数仅针对KV Cache模型权重仍可为原始精度(如FP16)或独立量化。 $$ \text{内存节省比例} = 1 - \frac{\text{FBITS}}{16} = 1 - \frac{8}{16} = 50\% $$ > 注:实际内存节省可能因对齐填充而略低于理论值。 ### 使用示例(伪代码) ```python from vllm import LLM, SamplingParams # 在支持FP8的GPU上加载模型,并指定KV Cache数据类型 llm = LLM( model="qwen1.5-7b", kv_cache_dtype="fp8_e4m3" # 关键配置 ) # 正常执行生成 outputs = llm.generate( ["人工智能的定义是"], SamplingParams(max_tokens=100) ) ``` ### 注意事项 1. **硬件兼容性**:当前仅限新架构GPU(如H100)[^2]。 2. **精度验证**:敏感任务(如金融、医疗)需测试输出质量。 3. **框架支持**:需vLLM≥0.4.0或TensorRT-LLM≥0.9.0[^2][^3]。 §§相关问题§§ 1. 如何检查当前GPU是否支持FP8计算? 2. 除了KV Cache,还有哪些组件可以使用FP8量化? 3. E4M3和E5M2格式在哪些场景下各有优劣? </think>### 参数 `kv_cache_dtype="fp8_e4m3"` 的含义详解 该参数指定键值缓存(KV Cache)的数据类型为 **FP8 E4M3 格式**,属于高性能推理中的显存优化技术。以下是分层解析: --- #### 1. **核心概念:KV Cache 的作用** 在 Transformer 架构(如 LLM)推理过程中: - **键(Key)和值(Value)张量** 需缓存以加速自回归生成 - 传统方案使用 FP16/BF1616位浮点),每 token 需约 2 字节显存[^1][^2] - `kv_cache_dtype` 通过降低精度减少显存占用 --- #### 2. **FP8 E4M3 的技术细节** - **数据类型**:8位浮点数(FP8) - **位分配**: - `E4`:4位指数(动态范围 $[-24, 15]$) - `M3`:3位尾数(精度 ≈ 2位十进制数) - 1位符号位 - **数值范围**:约 $±3.4 \times 10^4$(对比 FP16 的 $±6.5 \times 10^4$)[^1][^3] - **内存优势**: $$ \text{显存节省} = 1 - \frac{\text{FP8位数}}{\text{FP16位数}} = 1 - \frac{8}{16} = 50\% $$ --- #### 3. **硬件与计算优化** - **适用硬件**:专为 NVIDIA H100 / AMD MI300X 等支持 FP8 指令集的 GPU 设计[^2] - **计算过程**: 1. 原始 FP16 激活值经量化模块转为 FP8 2. FP8 格式数据存入 KV Cache 3. 计算时部分内核(如 FMHA)仍用 BF16 精度[^3] 4. 输出前反量化还原精度 - **性能收益**(基于引用[2]): - 吞吐量提升 ≤ 1.6 倍 - 显存需求降低 2 倍 - 精度损失通常 <1%(对多数生成任务无感)[^2] --- #### 4. **实际应用场景** ```python # vLLM 中的使用示例 from vllm import LLM, SamplingParams llm = LLM( model="Qwen3-1.7B", kv_cache_dtype="fp8_e4m3" # 关键配置 ) outputs = llm.generate(prompts, sampling_params) ``` - **适用场景**: - 长上下文推理(如 128K tokens) - 多任务批处理(batch_size > 16) - 边缘设备部署(显存受限环境) --- #### 5. **与其他参数的协同** - **需配合量化方法**:如权重使用 FP8 量化(引用[1]) - **与并行策略关联**: - 当启用张量并行(TP)时,KV Cache 分片存储 - 显存节约效果随并行度增大而提升(引用[4]): $$ \text{总显存需求} \propto \frac{\text{原始显存}}{ \text{TP} \times \text{量化系数} } $$ --- ### 总结 `kv_cache_dtype="fp8_e4m3"` 是通过 **将键值缓存精度降至 FP8(E4M3 格式)** 来优化推理的技术: 1. **显存减半**:从 FP1616 位 → FP8 的 8 位 2. **吞吐提升**:最高 1.6 倍加速(H100 GPU)[^2] 3. **精度可控**:尾数截断对语言模型输出影响微弱 4. **硬件绑定**:需新一代 GPU 支持 FP8 指令集 > 📌 注意:实际效果需测试验证,敏感任务(如数学推理)建议评估精度损失[^3]。
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值