The overall workflow for distributed training of a large model is:
- initialize the distributed environment;
- load the model onto the CPU;
- load the training data in a distributed (data-parallel) fashion;
- initialize the model shards;
- load each model shard onto its corresponding GPU;
- run distributed training;
- save the model.
Below we use the T5 model as an example to walk through how to train a large model on multiple GPUs with the FSDP distributed framework. FSDP supports both data parallelism and model sharding; to train a large model across multiple machines and multiple GPUs, however, it needs to be combined with the RPC distributed computing framework, which will be covered in detail in a later article on distributed training.
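The snippets below are adapted from the PyTorch FSDP advanced tutorial (see the reference at the end) and rely on a set of imports plus a few helper utilities (setup_model, wikihow, bfSixteen, cleanup, get_date_of_run, format_metrics_to_gb) that live in the tutorial's companion files and are not reproduced in this post. A plausible set of imports, given here as an assumption rather than the exact original, is:

import os
import time
import functools
import argparse
from distutils.version import LooseVersion

import tqdm
import torch
import torch.distributed as dist
import torch.optim as optim
from torch.cuda import nccl
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block
from datasets import load_dataset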
1. Distributed environment initialization:
def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
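fsdp_main below also calls a cleanup() helper that is not shown in this post; a minimal sketch is:

def cleanup():
    # tear down the process group once training is finished
    dist.destroy_process_group()

When the script is launched with torchrun, the RANK, LOCAL_RANK, and WORLD_SIZE environment variables read later in the code are set automatically for each process.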
2. Single-node multi-GPU training:
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    # accumulator for this rank's loss sum and sample count; later reduced across GPUs into a global loss
    fsdp_loss = torch.zeros(2).to(local_rank)
    if sampler:
        # re-seed the sampler each epoch so the shuffle order changes while each rank keeps a distinct shard of the data
        sampler.set_epoch(epoch)
    if rank == 0:  # rank 0 is the main process
        inner_pbar = tqdm.tqdm(range(len(train_loader)), colour="blue", desc="r0 Training Epoch")
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)  # move the current batch onto this process's GPU
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank == 0:
            inner_pbar.update(1)
    # reduce the per-rank accumulators across all processes to obtain a global training loss
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}")
    return train_accuracy
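train() and validation() assume each batch is a dict with source_ids, source_mask, and target_ids. The tutorial builds such batches with a wikihow dataset class that is not reproduced here; a minimal, hypothetical Dataset yielding batches of the same shape could look like this (illustrative only, not the tutorial's implementation):

class SummarizationDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that tokenizes (article, summary) pairs for T5."""
    def __init__(self, tokenizer, texts, summaries, input_length=512, output_length=150):
        self.tokenizer = tokenizer
        self.texts = texts
        self.summaries = summaries
        self.input_length = input_length
        self.output_length = output_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # tokenize source article and target summary to fixed lengths
        source = self.tokenizer(self.texts[idx], max_length=self.input_length,
                                padding="max_length", truncation=True, return_tensors="pt")
        target = self.tokenizer(self.summaries[idx], max_length=self.output_length,
                                padding="max_length", truncation=True, return_tensors="pt")
        return {
            "source_ids": source["input_ids"].squeeze(0),
            "source_mask": source["attention_mask"].squeeze(0),
            "target_ids": target["input_ids"].squeeze(0),
        }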
3. Single-node multi-GPU validation:
def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch")
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)
            if rank == 0:
                inner_pbar.update(1)
    # reduce the per-rank accumulators across processes to compute a global validation loss
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss
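fsdp_main below starts by calling setup_model("t5-base"), another helper from the tutorial that is not shown in this post. Assuming Hugging Face transformers, a minimal version could be:

from transformers import T5ForConditionalGeneration, T5Tokenizer

def setup_model(model_name):
    # load the weights and tokenizer on the CPU; FSDP will shard and move them to the GPUs later
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer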
4. Data-parallel data loading and sharded (model-parallel) model loading for single-node multi-GPU training:
def fsdp_main(args):
    model, tokenizer = setup_model("t5-base")  # for very large models, load onto the CPU first to avoid running out of memory
    local_rank = int(os.environ['LOCAL_RANK'])  # this process's GPU index on the local node
    rank = int(os.environ['RANK'])  # this process's global rank, which determines its role and lets it coordinate with the other processes
    world_size = int(os.environ['WORLD_SIZE'])  # total number of processes in the distributed job
    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)
    # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
    # data-parallel preprocessing: each sampler yields only the indices this rank should process
    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
    setup(rank, world_size)  # initialize the distributed environment
    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}  # DataLoader settings that speed up data loading
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
    # wrap every T5Block as its own FSDP unit so transformer layers are sharded individually
    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP  # ZeRO-2; use FULL_SHARD for ZeRO-3
    torch.cuda.set_device(local_rank)
    # init_start_event = torch.cuda.Event(enable_timing=True)
    # init_end_event = torch.cuda.Event(enable_timing=True)
    # init_start_event.record()
    bf16_ready = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None  # defaults to fp32
    # model is on CPU before being passed to FSDP, which shards it onto the GPUs
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        # sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"
    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()
    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()
        if rank == 0:
            print(f"--> epoch {epoch} completed...entering save and stats zone")
            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())
            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())
            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")
        if args.save_model and curr_val_loss < best_val_loss:
            # save: gather a full (unsharded) state dict, offloaded to CPU and kept only on rank 0
            if rank == 0:
                print(f"--> entering save model state")
            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            # print(f"saving process: rank {rank} done w state_dict")
            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")
                torch.save(cpu_state, save_name)
        if curr_val_loss < best_val_loss:
            best_val_loss = curr_val_loss
            if rank == 0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")
    dist.barrier()
    cleanup()
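The code above also references bfSixteen (the bf16 mixed-precision policy) and an args namespace. Sketches of both, together with a typical launch command, are shown below; the argparse defaults and the script name are assumptions, not values from the tutorial.

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,   # gradient communication precision
    buffer_dtype=torch.bfloat16,   # buffer precision
)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='T5 FSDP training example')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--test-batch-size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.002)
    parser.add_argument('--gamma', type=float, default=0.7)
    parser.add_argument('--run_validation', action='store_true', default=True)
    parser.add_argument('--save-model', action='store_true', default=True)
    parser.add_argument('--track_memory', action='store_true', default=True)
    args = parser.parse_args()
    fsdp_main(args)

A single-node run on four GPUs can then be launched with, for example: torchrun --nnodes 1 --nproc_per_node 4 t5_training.py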
The remaining content … to be updated.
Reference: https://pytorch.org/tutorials/intermediate/FSDP_advanced_tutorial.html

This article has shown how to train a large T5 model on multiple GPUs with the FSDP (Fully Sharded Data Parallel) framework. It first covered environment initialization, loading the model, sharding the data, and placing the model shards onto the GPUs; it then walked through the training and validation steps, including data-parallel loading, sharded model loading, loss aggregation, and global synchronization across processes; finally, it touched on memory tracking and model checkpointing.