文章目录
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
Pytorch的单机多GPU训练
1)多GPU训练介绍
当我们使用的模型过大,训练数据比较多的时候往往需要在多个GPU
上训练。使用多GPU
训练时有两种方式,一种叫ModelParallelism
,一种是DataParallelism
。
ModelParallelism
方式,是在模型比较大导致一张显卡放不下的时候,将模型拆分然后分别放到不同的显卡上,将同一份数据分别输入进行模型训练。这种对模型结构各模块之间有联系时很不友好,有可能都不支持拆分。因此,应用更广泛的是DataParallelism
的方式。
DataParallelism
方式,是将相同的模型拷贝到不同的显卡上,然后将数据平均划分后输入到相应显卡上进行计算,然后根据计算结果更新模型的参数。
DataParallelism
方式更新模型参数时,因为每个显卡上都有一个完整的模型,其可以单独根据一个显卡的运算结果更新参数,即异步更新,也可以将各个显卡的运算结果汇总后再根据总的运算结果一次性更新模型参数,即同步更新。因此,使用DataParallelism
模型参数的更新有两种选择方式,不过值得注意的是不同显卡上的模型参数是共享的,也就是虽然不同显卡上都有完整的模型,但模型参数用的是同一份,都是相同的。 所以在模型初始化的时候就要给不同显卡上的模型初始化相同的权重值。根据两种权重更新策略的区别,可以发现,对于单个显卡上batch_size
本身就比较大的情况,可以使用异步更新,这样不需要显卡之间运算同步,可以提升训练速度;而对于batch_size
比较小的情况,根据mini_batch
随机梯度下降算法的原理,最好选用同步更新的方式,保证学习效果。
图片引用自【分布式训练】单机多卡的正确打开方式(一):理论基础
![]()
参数同步更新![]()
参数异步更新
使用多GPU
训练时,还需要注意的是使用BatchNormalization
的情况,对于BN
层归一化时,是在单个显卡上计算,还是在不同的显卡之间做同步再计算,同样,对于batch_size
比较大时建议使用异步运算,小时使用同步计算以保证模型学习的效果。
2)pytorch中使用单机多GPU
训练
相对于tensorflow
来说,pytorch
中设置模型进行多GPU
训练的方式就显的简单多了。在这里只介绍现在pytorch
中使用最多的多GPU
训练方式即使用DistributedDataParallel
类。
DistributedDataParallel
(DDP)相关变量及含义
DDP
支持在多个机器中进行模型训练,其中每个机器被称之为节点Node
,每个机器上有可能有多个GPU
,为了不受GIL
的限制,DDP
会针对每个GPU
启动一个进程进行训练,每个进程在对应机器上的编号使用环境变量LOCAL_RANK
进行标识。
一次训练,在所有Node
上启动的训练进程总和使用WORLD_SIZE
来统计。而在分布在所有Node
的上某个进程在全局所有进程中的序号使用环境变量RANK
进行记录。
介绍到这DDP
的整体原理和使用的变量就很清楚了,

DDP
参考上图,是假设有3
台机器,每台机器上有2
个GPU
的情况。值的注意的是master_address
和master_port
上的参数,这两个参数是告诉其他进程主进程(RANK=0
的进程)的端口号和IP
地址,以便于其与主进程之间进行通信,包括数据交换,同步等。
下面几部分,就分别对pytorch
模型实现单机多GPU
训练要进行哪些设置分别进行介绍。
a)初始化
在编写多GPU
训练的代码时,需要先对环境进行初始化,需要调用init_process_group
来初始化默认的分布式进程组(default distributed process group
)和分布式包(distributed package
)。使用的是pytorch
的torch.distributed.init_process_group
方法。
该方法原型:
torch.distributed.init_process_group(backend=None, \
init_method=None, \
timeout=datetime.timedelta(seconds=1800), \
world_size=-1, \
rank=-1, \
store=None, \
group_name='', \
pg_options=None)
函数参数:
backend
: 参数类型为str or Backend
,根据pytorch
编译时的配置来选择,支持mpi/gloo/nccl/ucc
,这个后端指的是多GPU
之间进行通信的方式,根据不同类型的GPU
进行选择,对于NVIDIA
的GPU
一般选择nccl
,对于Intel
的GPU
一般选择ucc
。init_method
: 参数类型为str
,指定初始化方法,一般使用env://
,表示使用环境变量MASTER_ADDR
和MASTER_PORT
来初始化。和store
变量是互斥的。timeout
: 参数类型为datetime.timedelta
,指定初始化超时时间,如果超时则抛出异常。world_size
: 参数类型为int
,指定进程组的大小,如果为-1
,则使用环境变量WORLD_SIZE
来指定,定义store
变量时必须指定world_size
。rank
: 参数类型为int
,指定当前进程在进程组中的排位,如果为-1
,则使用环境变量RANK
来指定,定义store
变量时,必须指定rank
。store
: 参数类型为Store
,指定用于保存分布式训练状态的存储Key/Value
对象,用于交换连接/地址信息,所有的进程都能访问,和init_method
方法互斥。group_name
: 参数类型为str
,指定进程组的名字,这个变量已经是deprecated
了。pg_options
: 参数类型为ProcessGroupOptions
,指定进程组的其他选项,如allreduce_post_hook
等,目前仅对nccl
后端支持ProcessGroupNCCL.Options
选项。
使用torch.distributed.init_process_group
初始化进程组的两种方式:
- 指定
store/rank/world_size
- 指定
init_method
,明确给出进程间在哪通过哪种协议发现其他进程并通信,此时rank/world_size
是可选的
初始化后,进程组可以通过torch.distributed.get_world_size()
和torch.distributed.get_rank()
来获取进程组大小和当前进程在进程组中的排位。
所以最简单的初始化方式,只需要指定后端即可:
torch.distributed.init_process_group(backend='nccl')
每个进程的环境变量RANK
是在启动时由torchrun
命令行工具自动添加的,WORLD_SIZE
是在torchrun
启动时根据启动的进程数自动添加的。
b)数据准备
在pytorch
中,数据的准备是先实例化torch.utils.data.Dataset
的数据类,然后再将其放入数据加载器torch.utils.data.DataLoader
中,以控制加载数据的进程数num_worker
、采样器sampler
和batch_size
大小等。
在使用DistributedDataParallel
实现训练时,在数据加载器中上需要使用两个采样器sampler = DistributedSampler(data)
和batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
来指定数据采样器,这样可以保证每个进程每个batch
只处理属于自己的数据。
这里一起来看下DistributedSampler
和BatchSampler
。
DDP
模式就是将数据均分到多个GPU
上来优化算法,对于每个GPU
该如何从总的训练数据中采样属于自己用的数据,这就需要一个采样策略,这正是DistributedSampler
发挥的作用。

DistributedSampler
如上图,假设有11
个样本,GPU
的数量为2
,DistributedSampler
的作用先是把数据打散,然后均分到每个gpu
上,当数据不组时,会采用循环重复的策略来补满。
torch.utils.data.BatchSampler
则是指定每个batch
的样本数量,以及是否丢弃最后一个可能不足的batch
。当设置drop_last=True
时,会将最后不足一个batch
的数据丢弃。

BatchSampler
上面介绍的过程是对于一轮数据训练时数据加载器的工作过程,对整个训练过程,为了保证学习的效果,需要在每个epoch
设置采样器能重新打散数据,因此要在每一轮训练开始前调用DistributedSampler
的set_epoch
方法。
sampler = DistributedSampler(data)
batch_sampler = torch.utils.data.BatchSampler(
sampler, batch_size, drop_last=True)
dataloader = torch.utils.data.Dataloader(data_set, batch_sampler=train_batch_sampler)
for i in range(epoches):
sampler.set_epoch(epoch)
.