Linux分布式训练之ddp踩坑记录
最近在用实验室服务器的时候发现用一张卡已经无法满足我的网络魔改了,于是乎还算系统的了解了以下关于Linux分布式训练的知识,主要分为DP(DataParallel)和DDP(DistributedDataParallel)两种,而DP比较简单,故本篇主要记录DDP的踩坑。
一方面希望给没用过分布式训练的同学一些启发,另一方面替自己留一些笔记。(个人能力有限,肯定有理解不足的地方,友友们轻喷,欢迎交流!)
分布式训练基础定义
分布式训练通常分为模型并行和数据并行。数据并行用于解决数据量太大导致一张显卡无法满足训练需求的问题,包括两种,也就是DP和DDP。
DP (nn.DataParallal)
DP是将一个batch的数据拆分到多张卡进行训练,通过这种并行计算方式解决了batch很大的问题。但他存在两个问题:
(1)单进程多线程:DP是单进程多线程,无法在多个机器工作,而且不能使用Apex进行混合精度训练。(暂时没用到还理解不够透彻)。同时它基于多线程的方式,虽然方便了信息的交换,但受困于GIL(全局解释器锁),会带来性能开销,不适用于计算密集型任务;
(2)存在效率问题,主卡性能和通信开销容易成为瓶颈,GPU利用率通常很低。
DDP (nn,parallel.DistributedDataParallel)
DDP是多进程的,每个GPU对应一个进程,适用于单机和多机情况,唯一不足是调试比较麻烦。
这里主要记录使用过程中遇到的一些bug。
(1)多个进程出现在0卡
如下图所示:在GPU0出现了两个进程:
最终发现是model在设置分布式参数之前进行了初始化,也就是
model = Model(config)
出现在了
dist.init_process_group(backend='nccl', init_method='env://')
前面。给出设置DDP参数以及网络初始化的代码如下:
# 先设置ddp参数
dist.init_process_group(backend='nccl', init_method='env://')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
rank = int(os.environ.get("RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))
# os.system('nvidia-smi')
device = torch.device(f'cuda:{local_rank}')
print(f"[{datetime.now()}] Process RANK={rank}, LOCAL_RANK={local_rank}, WORLD_SIZE={world_size} is using device {device}")
# 初始化网络
model = Model(config)
create_log_dir(config)
model = model.to(device)
# os.environ["RANK"] = "0"
# os.environ["WORLD_SIZE"] = "2"
torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # 将普通的bn转换为sync_bn
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
当然,我们还参考了其他的解决方案,虽然对我们的bug没有奏效,但是或许可以给大家带来参考:
(1) 【DDP踩坑记录】在0卡上出现多个进程: https://blog.youkuaiyun.com/Ll7_ll/article/details/133362858
(2) Pytorch使用DDP加载模型时出现多进程在GPU0上占用过多显存的问题: https://blog.51cto.com/u_15786578/5667478
Sync_BN
此外,还需要注意,在进行ddp的时候需要使用同步BN。在分布式数据并行环境中,每个GPU只能访问本地的mini-batch数据,普通的batchnorm层只会在本地mini-batch上计算均值和方差,这在分布式训练中可能导致不一致的统计量。而SyncBatchNorm可以使不同GPU的mini-batch数据将通过通信机制同步,并且计算得到全局均值和方差,用于归一化操作。
注意事项
1、必须在DDP模式下使用:SyncBatchNorm依赖进程间通信(如NCCL),需要在torch.nn.parallel.DistributedDataParallel中运行;
2、使用SyncBatchNorm会增加通信开销,可能导致训练速度下降;
3、必须在torch.nn.parallel.DistributedDataParallel包装模型之前调用:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])