1. FP8混合精度训练
(1) 定义
FP8混合精度训练是一种深度学习训练优化技术,利用 8位浮点数(FP8) 表示部分模型参数和计算结果,同时结合更高精度(如 FP16 或 FP32)进行关键计算,从而在保证模型精度的前提下显著降低计算成本和内存占用。
FP8(8-bit Floating Point)是一种新兴的数值表示格式,具有更低的存储需求和计算复杂度。相比传统的 FP32(32位浮点数)和 FP16(16位浮点数),FP8 的表示范围更小,但通过混合精度训练策略,可以在不显著损失模型性能的情况下,提升训练效率。
2. 核心功能
(1) 减少内存占用
- 功能:FP8 格式的数值表示仅需 8 位存储空间,相比 FP32 减少了 75% 的内存需求。
- 实现方式:
- 将部分模型参数(如激活值、梯度)存储为 FP8 格式。
- 在训练过程中动态调整 FP8 和更高精度格式之间的转换。
(2) 加速计算
- 功能:FP8 格式的计算复杂度更低,能够显著提升训练速度。
- 实现方式:
- 使用支持 FP8 运算的硬件(如 NVIDIA Hopper GPU)加速矩阵乘法和卷积操作。
- 在关键计算(如梯度累积)中使用更高精度(FP16/FP32)以保证数值稳定性。
(3) 保证数值稳定性
- 功能:通过混合精度策略,在关键计算中使用更高精度,避免因 FP8 的有限表示范围导致的数值溢出或下溢。
- 实现方式:
- 在前向传播中使用 FP8 表示激活值。
- 在反向传播中使用 FP16 或 FP32 计算梯度。
(4) 提升硬件利用率
- 功能:FP8 格式的低存储需求和高计算效率能够更好地利用硬件资源。
- 实现方式:
- 在支持 FP8 的硬件上(如 NVIDIA Hopper GPU)充分利用其专用的 Tensor Core 加速器。
3. 技术要素
(1) FP8 数值格式
- 核心思想:FP8 是一种 8 位浮点数格式,通常分为两种变体:
- E4M3:4 位指数,3 位尾数,适用于动态范围较大的数据。
- E5M2:5 位指数,2 位尾数,适用于动态范围较小但需要更高精度的数据。
- 关键技术:
- 动态