Overall framework
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train(local_rank, world_size):
    # MyDataset, MyModel, compute_loss, and epochs are placeholders defined elsewhere
    # Initialize the process group; NCCL is the standard backend for multi-GPU training
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    # DistributedSampler gives each process a distinct shard of the dataset
    dataset = MyDataset()
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler)

    # Wrap the model in DDP so gradients are synchronized across processes
    model = MyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(epochs):
        # Reseed the sampler so each epoch sees a different shuffling of the shards
        sampler.set_epoch(epoch)
        for batch in dataloader:
            batch = batch.to(local_rank)  # move the batch to this process's GPU
            outputs = model(batch)
            loss = compute_loss(outputs)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Only rank 0 saves the checkpoint; unwrap the DDP wrapper via model.module
    if local_rank == 0:
        torch.save(model.module.state_dict(), "model.pth")

    dist.destroy_process_group()
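As a minimal sketch of how this train function could be started on a single node (one process per GPU, which is an assumption of this example), torch.multiprocessing.spawn passes each process its rank as the first argument:

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # spawn() calls train(rank, world_size) once per process, supplying the rank automatically
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)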
Distributed launch command