A Detailed Walkthrough of Training Large Models with the FSDP Distributed Training Framework

This article describes how to use the FSDP (FullyShardedDataParallel) framework to train a large T5 model across multiple GPUs. It first walks through distributed environment initialization, model loading, data sharding, and placing model shards on GPUs. It then presents the training and validation steps, including data-parallel loading, sharded model loading, loss aggregation, and global gradient synchronization. Finally, it covers tracking memory allocation and saving the model.
Distributed training of a large model proceeds as follows:
  1. Initialize the distributed environment;
  2. Load the model onto the CPU;
  3. Load the training data in a distributed fashion;
  4. Initialize the model shards;
  5. Place each model shard on its corresponding GPU;
  6. Run distributed training;
  7. Save the model.

Below we use the T5 model as an example to show in detail how to train a large model on multiple GPUs with the FSDP distributed framework. FSDP is a distributed training framework that supports both data parallelism and model (parameter) sharding; to scale a large model to multi-node, multi-GPU training, it needs to be combined with the RPC distributed computing framework, which will be covered in a follow-up article in this series.

1. Initializing the distributed environment:
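
The original post omits its imports. For reference, here is a minimal set that the snippets below appear to assume (module paths follow PyTorch, HuggingFace transformers, and datasets; treat this as a sketch rather than the author's exact header):

import functools
import os
import time
from distutils.version import LooseVersion

import torch
import torch.distributed as dist
import torch.optim as optim
import tqdm
from datasets import load_dataset
from torch.cuda import nccl
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block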

def setup(rank, world_size):
    # initialize the default process group over NCCL for multi-GPU communication
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
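
cleanup() is invoked at the end of fsdp_main below but never defined in the post; a matching definition, as in the PyTorch FSDP tutorial, is:

def cleanup():
    dist.destroy_process_group() # tear down the process group once training finishes

The rank and world_size arguments, along with the LOCAL_RANK, RANK, and WORLD_SIZE environment variables read later, are supplied by the launcher. With torchrun, a single-node run on 4 GPUs would look roughly like this (the script name T5_training.py is a placeholder):

torchrun --nnodes=1 --nproc_per_node=4 T5_training.py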

2. Single-node multi-GPU training loop:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank) # slot 0 accumulates the summed loss, slot 1 counts batches; all-reduced later into a global loss

    if sampler:
        sampler.set_epoch(epoch) # vary the shuffle seed each epoch; the sampler already gives each process its own shard of the data
    if rank==0: # rank 0 acts as the main process for logging
        inner_pbar = tqdm.tqdm(range(len(train_loader)), colour="blue", desc="r0 Training Epoch")

    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank) # move the current batch onto this process's GPU

        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])

        loss = output["loss"]
        loss.backward() # FSDP synchronizes and reshards gradients during the backward pass
        optimizer.step()

        fsdp_loss[0] += loss.item() # accumulate the batch-mean loss
        fsdp_loss[1] += 1 # count batches so the ratio below is a mean loss
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) # sum the loss statistics across all processes; FSDP already synchronizes gradients during backward
    train_accuracy = fsdp_loss[0] / fsdp_loss[1] # global mean training loss

    if rank == 0:
        inner_pbar.close()
        print(f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}")
    return train_accuracy

3. Single-node multi-GPU validation loop:

def validation(model, rank, world_size, val_loader):
    model.eval()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank) # slot 0: summed loss, slot 1: batch count
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch")

    with torch.no_grad(): 
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank) # move the batch onto this process's GPU
            output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item() # accumulate the batch-mean loss
            fsdp_loss[1] += 1 # count batches so the ratio below is a mean loss

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) # sum the loss statistics across all processes; this aggregates metrics, not gradients
    val_loss = fsdp_loss[0] / fsdp_loss[1]

    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

4. Data-parallel data loading and sharded model placement for single-node multi-GPU training:

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base") # load the model onto the CPU first; for very large models this avoids GPU OOM before sharding

    local_rank = int(os.environ['LOCAL_RANK']) # this process's GPU index on the current node (set by the launcher, e.g. torchrun)
    rank = int(os.environ['RANK']) # this process's global rank across the whole job, used to coordinate with the other processes
    world_size = int(os.environ['WORLD_SIZE']) # total number of processes participating in the job


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
    
    # data-parallel partitioning: each sampler yields only the indices assigned to this rank
    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup(rank, world_size) # initialize the distributed environment (see setup() above)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False} # shuffle=False because the samplers shuffle; pin_memory speeds up host-to-GPU copies

    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    # wrap every T5Block in its own FSDP unit so parameters are sharded per transformer layer
    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP # ZeRO-2; use FULL_SHARD for ZeRO-3 (FSDP defaults to FULL_SHARD when none is passed)
    torch.cuda.set_device(local_rank)


    # bf16 needs CUDA >= 11.0, a bf16-capable GPU, and NCCL >= 2.10
    bf16_ready = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
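
    # NOTE: bfSixteen is used below but not defined in this post. A MixedPrecision
    # policy consistent with the PyTorch FSDP tutorial (an assumption here) would be:
    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,  # keep sharded parameters in bf16
        reduce_dtype=torch.bfloat16, # communicate gradients in bf16
        buffer_dtype=torch.bfloat16, # keep buffers in bf16
    )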

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            # gather a full, unsharded state_dict, offloaded to CPU and materialized on rank 0 only
            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()
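
For completeness: setup_model, get_date_of_run, format_metrics_to_gb, and the wikihow dataset class are all used above but not shown in the post. The following is a minimal sketch consistent with the PyTorch FSDP tutorial and its companion code; the WikiHow field names ('text', 'headline') and other details are assumptions, not the author's exact code:

from datetime import datetime

from torch.utils.data import Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer

def setup_model(model_name):
    # load model and tokenizer onto the CPU; FSDP moves the shards to GPUs later
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

def get_date_of_run():
    # timestamp embedded in checkpoint file names
    return datetime.now().strftime("%Y-%m-%d-%H:%M:%S")

def format_metrics_to_gb(item):
    # convert bytes from torch.cuda.memory_allocated()/memory_reserved() to GB
    return round(item / (1024 ** 3), ndigits=4)

class wikihow(Dataset):
    # tokenize WikiHow articles for text -> headline summarization
    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
        self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
        if num_samples:
            self.dataset = self.dataset.select(range(num_samples))
        self.tokenizer = tokenizer
        self.input_length = input_length
        self.output_length = output_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        example = self.dataset[index]
        source = self.tokenizer(example['text'], max_length=self.input_length,
                                padding='max_length', truncation=True, return_tensors='pt')
        target = self.tokenizer(example['headline'], max_length=self.output_length,
                                padding='max_length', truncation=True, return_tensors='pt')
        return {'source_ids': source['input_ids'].squeeze(),
                'source_mask': source['attention_mask'].squeeze(),
                'target_ids': target['input_ids'].squeeze()}

fsdp_main also expects an args namespace carrying batch_size, test_batch_size, lr, gamma, epochs, run_validation, save_model, and track_memory. A hypothetical entry point wiring these together:

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="T5 FSDP training example")
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--test-batch-size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.002)
    parser.add_argument('--gamma', type=float, default=0.7) # StepLR decay factor
    parser.add_argument('--run-validation', action='store_true', default=True)
    parser.add_argument('--save-model', action='store_true', default=False)
    parser.add_argument('--track-memory', action='store_true', default=True)
    args = parser.parse_args()
    fsdp_main(args)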

The remaining content … … to be updated.

Reference: https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html
