pytorch 训练过程acc_PyTorch分布式训练简明教程

本文介绍了PyTorch中进行多GPU训练的两种方法:nn.DataParallel和nn.DistributedDataParallel,重点讲解了DistributedDataParallel的原理和优势,以及如何使用它进行分布式训练。通过一个MNIST的例子,展示了如何转换为分布式训练并加入混合精度训练,帮助理解PyTorch的分布式训练机制。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

神经网络训练加速的最简单方法是使用GPU,对弈神经网络中常规操作(矩阵乘法和加法)GPU运算速度要倍超于CPU。随着模型或数据集越来越大,一个GPU很快就会变得不足。例如,BERT和GPT-2等大型语言模型是在数百个GPU上训练的。对于多GPU训练,需要一种在不同GPU之间对模型和数据进行切分和调度的方法。

PyTorch是非常流行的深度学习框架,它在主流框架中对于灵活性和易用性的平衡最好。Pytorch有两种方法可以在多个GPU上切分模型和数据:nn.DataParallel和nn.distributedataparallel。DataParallel更易于使用(只需简单包装单GPU模型)。然而,由于它使用一个进程来计算模型权重,然后在每个批处理期间将分发到每个GPU,因此通信很快成为一个瓶颈,GPU利用率通常很低。而且,nn.DataParallel要求所有的GPU都在同一个节点上(不支持分布式),而且不能使用Apex进行混合精度训练。nn.DataParallel和nn.distributedataparallel的主要差异可以总结为以下几点(译者注):DistributedDataParallel支持模型并行,而DataParallel并不支持,这意味如果模型太大单卡显存不足时只能使用前者;

DataParallel是单进程多线程的,只用于单机情况,而DistributedDataParallel是多进程的,适用于单机和多机情况,真正实现分布式训练;

DistributedDataParallel的训练更高效,因为每个进程都是独立的Python解释器,避免GIL问题,而且通信成本低其训练速度更快,基本上DataParallel已经被弃用;

必须要说明的是DistributedDataParallel中每个进程都有独立的优化器,执行自己的更新过程,但是梯度通过通信传递到每个进程,所有执行的内容是相同的;

总的来说,Pytorch文档是相当完备和清晰的,尤其是在1.0x版本后。但是关于DistributedDataParallel的介绍却较少,主要的文档有以下三个:

这篇教程将通过一个MNISI例子讲述如何使用PyTorch的分布式训练,这里将一段段代码进行解释,而且也包括任何使用apex进行混合精度训练。

DistributedDataParallel内部机制

DistributedDataParallel通过多进程在多个GPUs间复制模型,每个GPU都由一个进程控制(当然可以让每个进程控制多个GPU,但这显然比每个进程有一个GPU要慢;也可以多个进程在一个GPU上运行)。GPU可以都在同一个节点上,也可以分布在多个节点上。每个进程都执行相同的任务,并且每个进程都与所有其他进程通信。进程或者说GPU之间只传递梯度,这样网络通信就不再是瓶颈。

在训练过程中,每个进程从磁盘加载batch数据,并将它们传递到其GPU。每一个GPU都有自己的前向过程,然后梯度在各个GPUs间进行All-Reduce。每一层的梯度不依赖于前一层,所以梯度的All-Reduce和后向过程同时计算,以进一步缓解网络瓶颈。在后向过程的最后,每个节点都得到了平均梯度,这样模型参数保持同步。

这都要求多个进程(可能在多个节点上)同步并通信。Pytorch通过distributed.init_process_group函数来实现这一点。他需要知道进程0位置以便所有进程都可以同步,以及预期的进程总数。每个进程都需要知道进程总数及其在进程中的顺序,以及使用哪个GPU。通常将进程总数称为world_size.Pytorch提供了nn.utils.data.DistributedSampler来为各个进程切分数据,以保证训练数据不重叠。

实例讲解

这里通过一个MNIST实例来讲解,我们先将其改成分布式训练,然后增加混合精度训练。

普通单卡训练

首先,导入所需要的库:

import os

from datetime import datetime

import argparse

import torch.multiprocessing as mp

import torchvision

import torchvision.transforms as transforms

import torch

import torch.nn as nn

import torch.distributed as dist

from apex.parallel import DistributedDataParallel as DDP

from apex import amp

然后我们定义一个简单的CNN模型处理MNIST数据:

class ConvNet(nn.Module):

def __init__(self, num_classes=10):

super(ConvNet, self).__init__()

self.layer1 = nn.Sequential(

nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值