PyTorch Float8 实验项目教程
1. 项目介绍
float8_experimental
是一个实验性的 PyTorch 库,旨在通过使用 float8 数据格式加速训练过程。该项目遵循 arXiv 论文 中提出的方法,致力于在 PyTorch 中实现 float8 训练的本地支持。该库的目标是保持代码简洁、易于调试,并与 PyTorch 的核心系统(如 autograd、torch.compile 和分布式计算)兼容。
2. 项目快速启动
安装
首先,确保你使用的是最新的 PyTorch nightly 版本,以获得最佳的 torch.compile
支持。
pip install torch --pre
然后,安装 float8_experimental
库:
pip install git+https://github.com/pytorch-labs/float8_experimental.git
示例代码
以下是一个简单的单 GPU 训练示例,使用动态缩放策略:
from float8_experimental import convert_to_float8_training
# 创建模型
m = Model()
# 将所有 `torch.nn.Linear` 模块转换为 `Float8Linear`
convert_to_float8_training(m)
# 启用 torch.compile 以提高性能
m = torch.compile(m)
# 训练循环
for _ in range(N_ITER):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
optimizer.step()
3. 应用案例和最佳实践
动态缩放 vs. 延迟缩放
- 动态缩放:每个张量动态调整其缩放因子,适用于需要高精度的场景。
- 延迟缩放:在训练过程中延迟缩放因子的计算,适用于需要高性能的场景。
多 GPU 支持
float8_experimental
与 DTensor 分布式 API(如 FSDP、TP 和 SP)兼容。以下是一个使用 FSDP 的多 GPU 示例:
from float8_experimental import convert_to_float8_training
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# 创建模型
m = Model()
# 转换为 float8 训练
convert_to_float8_training(m)
# 使用 FSDP
model = FSDP(model, use_orig_params=True)
# 启用 torch.compile
m = torch.compile(m)
# 训练循环
for _ in range(N_ITER):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
optimizer.step()
4. 典型生态项目
PyTorch 生态系统
- torch.compile:用于优化模型性能。
- torch.distributed:用于分布式训练。
- torch.autograd:用于自动微分。
相关项目
- torch.ao:PyTorch 的自动优化库,包含
float8_experimental
的最新版本。 - torch.titan:用于大规模分布式训练的示例和工具。
通过结合这些工具和库,float8_experimental
可以在各种训练场景中提供显著的性能提升。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考