GPT-2(XL)有 15 亿个参数,在 16 位精度下会消耗大约 3GB 的显存(在 16 位精度下,一个参数占用 2 字节的显存)。
在单个 GPU 上训练 GPT-2 所需的最小显存是多少?
- 优化器(Adam)
- 批大小(32)
- 层数(48)
- 序列长度(1000)
答案是?
- 4-6GB
- 8-10GB
- 12-15GB
- 32-35GB
- 50+GB
答案可能会令你大吃一惊。
在一个 16GB 显存的 GPU 上训练 GPT-2 几乎不可能。
咋一听,这怎么可能呢?显存都用到哪儿去了呢?
让我们来理一理。
实际上,模型在训练的过程中,很多地方都会持续的消耗显存。
优化器状态、梯度信息和模型参数
如何计算模型训练时的内存需求?模型计算显存主要由模型参数、优化器状态、梯度信息以及激活值等多方面因素共同决定。
混合精度训练常常被应用于加速模型训练。
顾名思义,混合精度就是同时利用较低精度的float16(如卷积和矩阵乘法)以及较高精度的float32。
在正向传播和反向传播中都使用 16 位来表示权重和梯度。因此,如果模型有 X 个参数,那么:
- 权重将消耗 2X 字节
- 梯度将消耗 2X 字节
此外,为了确保计算的有效性,在反向传播的最后仍然采用 32 位进行计算。如下图圆圈中所示:
Adam 优化器是模型训练中最受欢迎的优化器之一。
很多人仅仅因为它的流行,都采用 Adam 优化器。但他们没有意识到,采用 Adam 优化器后,在训练过程中,Adam 存储了两个优化器状态来计算更新:梯度的动量和方差。
因此,如果模型有 X 个参数,那么,这两个优化器状态消耗的显存如下:
- 动量:4X 字节(32 位)
- 方差:4X 字节(32 位)
最终更新都是以模型权重的 32 位表示进行调整,这导致另外 4X 字节用于模型参数。
算下来,目前共占用了
,也就是 24GB 显存,这比上面说的 3GB 高多了吧。
Activation
对于大型深度学习模型(如 LLM),在训练过程中Activation会占用大量内存。GPT-2 的一个 Transformer block 计算如下:
因此,所有 Transformer block 的计算如下:
下面是 GPT2-XL 的配置:
这相当于大约 30GB。由于每个Activation都以 16 位表示,因此所有Activation总共消耗大约 60GB 的显存。
使用梯度检查点技术,可能会降低到 8-9GB 左右,但代价是增加25-30%的运行时。这个技术将总显存消耗控制在了 32-35GB 左右,除此之外,还有一些额外的开销,比如内存碎片(当分配的存储块之间存在小的、未使用的间隙时,将会导致可用内存的低效使用)。
最后,来举个显存计算的例子。
假设有一个拥有700亿个参数的模型,使用float16精度进行训练,批大小为32,序列长度为512,隐藏层大小为4096,80 层,使用Adam fp32优化器。
模型参数内存 = 700亿 × 2字节 ≈ 140GB
激活内存(前向 + 反向)= 80层 × 288MB × 2 ≈ 45GB
-
每层的激活内存计算:
-
序列长度512 × 隐藏层大小4096 × batch size 32 = 67,108,864个元素
-
每个元素float16占2字节 → 134MB
-
考虑反向传播需要保存前向的激活值 → ×2 = 268MB
-
加上注意力机制等额外开销 ≈ 288MB/层
-
梯度内存 = 700亿 × 2字节 ≈ 140GB(梯度与参数量相同,数据类型为 float16)
Adam优化器需要维护:
-
一阶矩(动量,momentum):
float32
(4字节) -
二阶矩(方差,variance):
float32
(4字节) -
参数副本(可选):如果优化器使用
float32
,则还需存储float32
参数副本(4字节)
-
计算(使用
float32
优化器状态):-
动量 =
70B × 4字节 = 280GB
-
方差 =
70B × 4字节 = 280GB
-
参数副本(可选) =
70B × 4字节 = 280GB
-
总优化器状态内存 = 280GB + 280GB + 280GB = 840GB
-
缓冲区内存 = 框架开销(PyTorch/TensorFlow):4-8 GB,取 6GB
在 推理阶段,KV缓存用于存储 Key
和 Value
矩阵,避免重复计算:
-
每token每层 KV 缓存大小:
-
2 × hidden_size × 2字节 = 2×4096×2=16KB
-
-
batch=32, seq_len=512, 80层:
-
32 × 512 × 80 × 16KB ≈ 20GB
-
总内存 = 140GB + 45GB + 140GB + 840GB + 6 + 20 = 1191GB
额外开销 = 1191× 10% = 119GB
总计 = 1191 + 119 = 1310GB
推荐一个简便的可视化工具:https://rahulschand.github.io/gpu_poor/
还有一种粗略估算方法:参数量乘对应精度byte的基础上,再乘1.2的系数。
例如,一个 1B(10 亿参数)的 FP32 模型,单个参数占 4 字节,那么模型大小为:模型大小 = 参数数量 × 每个参数字节数 = 10 亿 × 4 字节 = 40 亿字节 ≈ 4GB。对于一个 14B 参数的 FP32 模型,模型大小为 14 × 4 = 56GB,乘以 1.2 后,运行所需显存约为 67.2GB。