TL:RL
传统PyTorch模型加载工作流:
- 创建模型
- 加载权重(state_dict)
- 加载到模型中
- 移动到设备进行推理
问题: 大模型在步骤1和2时需要大量的内存。
基于Accelerate的优化工作流:
- 创建空模型(无权重)
- 决定每层位置(设备映射)
- 加载部分权重到内存
- 加载到模型中
- 移动到设备进行推理
- 重复步骤3,直到所有权重加载完成
关键步骤:
- 使用
init_empty_weights()
创建不占用内存的模型。 - 计算
device_map
以优化各层的设备分配。 - 采用权重分片(sharding)减少内存占用。
- 使用PyTorch hooks简化多设备运行,自动管理权重加载与转移。
模型加载和推理
传统的PyTorch模型加载流程通常包括以下步骤:
- 创建模型
- 在内存中加载权重(通常称为state_dict)
- 将权重加载到模型中
- 将模型移动到设备进行推理
然而,对于超大模型,这种方式变得困难。例如,加载一个67亿参数的模型在步骤一的模型创建就需要约26.8GB的CPU RAM。第二步还会再加载一份模型备份,即还会再需要26.8GB的CPU RAM。而且,以上步骤只是为了能将模型在步骤4移动到GPU上。
torch.load的特性
- 基于 PyTorch 的模型加载使用
torch.load
函数。 - 在模型加载过程中,张量数据会首先在 CPU 上进行反序列化操作,并存储在 CPU 内存中。通过
map_location
参数,可以动态调整张量数据所在的设备位置,以便将其移动到合适的设备(如 GPU)。
基于Accelerate的模型加载和推理
接下来,我会介绍如何利用Accelerate优化PyTorch的特性,以便加载和推理非常大的模型,即使它们无法完全放入RAM或单个GPU。新的工作流如下:
- 创建一个空模型(没有权重)
- 决定每一层放置的位置(在多个设备可用时)
- 加载部分权重到内存
- 将这些权重加载到空模型中
- 将权重移动到设备进行推理
- 重复步骤3,直到所有权重加载完成
Creating an empty model
- Meta Device类型的张量不携带任何数据,只需要定义形状,而不必担心 CPU 或 GPU 的内存限制。
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig.from_pretrained("bigscience/bloom")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
Computing a device map
device_map
是一个用于指定模型各层或模块在不同设备(如GPU或CPU)上存放位置的字典。它的主要目的是优化内存使用,使得在加载和推理大型模型时,可以充分利用可用的硬件资源。
device_map = infer_auto_device_map(model)
Sharding state dicts
- 传统的PyTorch模型权重保存方式(state_dict)对于大模型非常占用内存。为了解决这个问题,Accelerate采用了分片(sharding)保存方式,将模型权重分成多个文件,每个文件只包含部分权重。这样可以逐步加载权重,减少RAM占用。同时,
from_pretrained
方法支持通过device_map
和offload_folder
选项管理权重的加载与存储,提高了大模型的可用性和效率。
Running a model split on several devices
- Accelerate通过使用PyTorch的hooks机制,简化了模型在多个设备(GPU、CPU和磁盘)上的运行。具体来说,在每次前向传播前后,dispatch_model函数会自动添加hooks,确保模块的输入与权重位于同一设备。
- 如果权重被卸载到CPU,它会在前向传播前将其移动到GPU上;
- 如果权重在磁盘上,则会先加载到RAM,再转移到GPU上。
- 这种方法有效管理了内存,提升了大模型的运行效率。
参考文献
- https://huggingface.co/blog/accelerate-large-models
- https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#how-the-process-works-working-with-code
- https://huggingface.co/docs/transformers/accelerate