I just finished learning how to write parallel training in PyTorch, so here is a very simple piece of PyTorch parallel training code. The goal is that readers who have never touched this topic can learn to write parallel training while picking up as little new material as possible.
Complete code, main.py:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel


def setup():
    # Initialize the default process group; NCCL is the usual backend for multi-GPU training.
    dist.init_process_group('nccl')


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(1, 1)

    def forward(self, x):
        return self.layer(x)


class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data = torch.tensor([1, 2, 3, 4], dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index:index + 1]


ckpt_path = 'tmp.pth'


def main():
    setup()
    rank = dist.get_rank()
    pid = os.getpid()
    print(f'current pid: {pid}')
    print(f'Current rank {rank}')
    # Map each process (rank) to one GPU.
    device_id = rank % torch.cuda.device_count()

    dataset = MyDataset()
    # DistributedSampler gives each process a disjoint shard of the dataset.
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

    model = ToyModel().to(device_id)
    # DistributedDataParallel keeps the replicas in sync by all-reducing gradients.
    ddp_model = DistributedDataParallel(model, device_ids=[device_id])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # Rank 0 writes the initial weights; the barrier keeps the other ranks
    # from reading the checkpoint before the file is fully written.
    if rank == 0:
        torch.save(ddp_model.state_dict(), ckpt_path)
    dist.barrier()
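The listing is cut off at this point. A minimal way to finish it, assuming the usual DDP pattern of loading the shared checkpoint on every rank and running a short training loop (the dummy target, epoch count, and map_location remap are illustrative choices, not from the original), would be:

    # Assumed continuation: every rank loads the checkpoint written by rank 0,
    # remapping it from cuda:0 onto its own GPU.
    map_location = {'cuda:0': f'cuda:{device_id}'}
    ddp_model.load_state_dict(torch.load(ckpt_path, map_location=map_location))

    for epoch in range(2):                          # epoch count chosen for illustration
        sampler.set_epoch(epoch)                    # reshuffle the shards each epoch
        for x in dataloader:
            x = x.to(device_id)
            y = ddp_model(x)
            loss = loss_fn(y, torch.zeros_like(y))  # dummy target, for illustration only
            optimizer.zero_grad()
            loss.backward()                         # DDP all-reduces gradients here
            optimizer.step()

    cleanup()


if __name__ == '__main__':
    main()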
This article provides a simple example of PyTorch parallel training and explains how to train with multiple GPUs. By launching one process per GPU, partitioning the data with DistributedSampler, and synchronizing model parameters and gradients with DistributedDataParallel, the model is trained in parallel. The full code and walkthrough should help in understanding PyTorch's parallel training mechanism.
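To actually run the script, launch it with torchrun, which spawns one process per GPU and sets up the environment variables that init_process_group reads. For example, on a single machine with 2 GPUs (adjust the count to your hardware):

    torchrun --nproc_per_node=2 main.py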