当有多卡可用时,可以考虑用 PyTorch DDP 在多张卡上同时训练。本篇记录一种使用 DDP 训练的代码模板,要点包括:
- 兼容 DDP 多卡和普通单卡训练;
- DDP 的起始(`torch.distributed.init_process_group` 及相关设置)和结束(`torch.distributed.destroy_process_group`);
- DataLoader 的 sampler;
- 用 `torch.nn.SyncBatchNorm.convert_sync_batchnorm` 将各种 BatchNorm 层转成 SyncBatchNorm,多卡必做!本以为它同时兼容单卡(即转成 SyncBN 之后,若没用 DDP 打包模型,则其行为与普通 BN 一样),可以无脑加入代码模板。ChatGPT 说要在 `cuda()` 或 `to(device)` 之前转,防止在旧版 pytorch 上出错。
  - 但单卡不要转成 SyncBN,否则可能报错:`Default process group has not been initialized, please make sure to call init_process_group`。所以模板中只在 DDP 模式下转换;
- 用 `torch.nn.parallel.DistributedDataParallel` 打包模型;
- 用 rank 0 进程 save 用 DDP 打包后的模型,load 时各进程读取同一个 checkpoint(两者都要通过 `model.module` 访问原模型);
- 用 rank 0 进程 validate。
Code
- 参考 [1]
import argparse, os, json, datetime, socket
import torch
import torch.nn as nn
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
def free_port():
    """Return a TCP port number that the OS reports as currently unused.

    Binding to port 0 asks the OS to pick a free port; the socket is closed
    again before returning, so the port is only very likely (not guaranteed)
    to still be free when the caller uses it.

    Ref: https://www.cnblogs.com/mayanan/p/15997892.html
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    finally:
        sock.close()
    return port
def _validate(model, loader, device):
    """Run *model* over *loader* without gradients and return top-1 accuracy.

    Placeholder metric: assumes the model outputs per-class scores of shape
    (batch, num_classes) and ``y`` holds class indices -- TODO adapt to the task.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x.to(device, non_blocking=True))
            y = y.to(device, non_blocking=True)
            correct += (pred.argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total if total else 0.0


def main(gpu_id, world_size, args):
    """Train (and validate) the model; executed once per process.

    Args:
        gpu_id: local GPU index of this process (passed by ``mp.spawn`` under
            DDP; unused in single-device mode).
        world_size: number of processes (one per GPU under DDP, 1 otherwise).
        args: parsed command-line namespace. Under DDP, ``args.gpu`` and
            ``args.rank`` are overwritten for this process.
    """
    # ---- initialise DDP ----
    if args.ddp:
        args.gpu = gpu_id  # GPU id on this machine used by this process
        # global rank = node rank * GPUs per node + local GPU id
        args.rank = args.rank * world_size + gpu_id
        print("rank:", args.rank, ", gpu id:", gpu_id)
        dist.init_process_group(
            backend=args.backend,
            # localhost rendezvous: this template targets a single machine
            init_method=f"tcp://localhost:{args.port}",
            rank=args.rank,
            world_size=world_size,
            # timeout=datetime.timedelta(minutes=10)  # custom max wait time
        )
        torch.cuda.set_device(args.gpu)
        # Index by the *local* GPU id, not the global rank: they differ as
        # soon as more than one node takes part.
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Checkpointing / validation happen in exactly one process. In
    # single-device mode that is this process, whatever --rank says.
    is_main = (not args.ddp) or args.rank == 0
    pin = torch.cuda.is_available()

    # ---- augmentations ----
    train_trfm = transforms.Compose([
        # ...augmentations for training...
    ])
    val_trfm = transforms.Compose([
        # ...augmentations for validation...
    ])

    # ---- datasets ----
    train_ds = MyDataset(train_trfm)
    val_ds = MyDataset(val_trfm)

    # ---- samplers & data loaders ----
    if args.ddp:
        # Under DDP, shuffle / drop_last are set on the sampler ...
        train_sampler = torch.utils.data.DistributedSampler(
            train_ds, num_replicas=world_size, rank=args.rank, shuffle=True, drop_last=True)
        # ... and the DataLoader must NOT set shuffle (it has to stay False).
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=args.batch_size, num_workers=4, pin_memory=pin, sampler=train_sampler)
    else:
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=args.batch_size, num_workers=4, pin_memory=pin, shuffle=True, drop_last=True)
    # Validation runs only in the main process, so it must see the whole
    # validation set (no DistributedSampler sharding), in a fixed order and
    # without dropping the last partial batch.
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=32, num_workers=4, pin_memory=pin, shuffle=False, drop_last=False)

    # ---- model ----
    model = MyModel()
    if args.ddp:
        # BN -> SyncBN before moving to the device (safer on old PyTorch).
        # Only under DDP: SyncBN without a process group may raise.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    if args.ddp:
        model = DDP(model, device_ids=[args.gpu])
    # The bare module, used for state_dict save/load and for the
    # main-process-only validation below.
    raw_model = model.module if args.ddp else model

    # loss & optimiser as usual
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    start_epoch = global_step = 0

    # ---- resume from checkpoint ----
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(args.resume)
        # Every process reads the same file directly onto its own device,
        # so no cross-process broadcast is required.
        ckpt = torch.load(args.resume, map_location=device)
        raw_model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch'] + 1  # the saved epoch was completed
        # the saved global_step already counts every step taken; continue from it
        global_step = ckpt['global_step']

    # ---- training ----
    for epoch in range(start_epoch, args.epoch):
        print('\t', epoch, end='\r')
        if args.ddp:
            # the sampler seeds its shuffle with the epoch -> new order each epoch
            train_sampler.set_epoch(epoch)
        model.train()
        for i, (x, y) in enumerate(train_loader):
            print(i, end='\r')
            pred = model(x.to(device, non_blocking=True))
            # targets must be on the same device as the predictions
            loss = criterion(pred, y.to(device, non_blocking=True))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

        # ---- checkpoint & validation: main process only, to avoid duplicates ----
        if is_main:
            sd = {
                "epoch": epoch,
                "global_step": global_step,
                "optimizer": optimizer.state_dict(),
                "model": raw_model.state_dict(),
            }
            torch.save(sd, os.path.join(args.log_path, f"checkpoint-{epoch}.pth"))
            # Validate through the bare module: a forward through the DDP
            # wrapper may issue collectives (buffer broadcast) that the other
            # ranks, which skip this block, would never join.
            acc = _validate(raw_model, val_loader, device)
            print(f"epoch {epoch}: val acc {acc:.4f}")
        if args.ddp:
            # other ranks wait here until the main process has saved & validated
            dist.barrier()

    # ---- finish training ----
    if args.ddp:
        dist.destroy_process_group()
if '__main__' == __name__:
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset', type=str)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--log_path', type=str, default="log")
    parser.add_argument('--resume', type=str, default="", help="checkpoint to resume")
    # DDP
    parser.add_argument('--ddp', action="store_true")
    parser.add_argument('--rank', type=int, default=0, help="node rank (DDP only)")
    parser.add_argument('--backend', type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
    parser.add_argument('--port', type=int, default=10000)
    args = parser.parse_args()

    # DDP needs at least 2 GPUs; otherwise fall back to plain single-device
    # training (this also covers machines with no GPU at all).
    n_gpus = torch.cuda.device_count()
    if args.ddp and n_gpus < 2:
        print(f"Disabling DDP cuz only {n_gpus} GPU(s) available")
        args.ddp = False
    if args.ddp:
        args.port = free_port()  # pick 1 available port

    # Save the config only after all adjustments, so it records what actually ran.
    os.makedirs(args.log_path, exist_ok=True)
    with open(os.path.join(args.log_path, "config.json"), "w") as f:
        json.dump(vars(args), f, indent=1)

    if args.ddp:
        world_size = n_gpus
        print("world size:", world_size)
        # mp.spawn always starts its workers with the 'spawn' method itself;
        # mp.set_start_method('spawn') is unnecessary here and would raise
        # RuntimeError if a start method had already been set.
        mp.spawn(
            main,
            args=(world_size, args),
            nprocs=world_size,
            join=True,
        )
    else:
        main(0, 1, args)
Test DDP Backend
PyTorch DDP 有三种自带的 backend:nccl、gloo、mpi,前两种比较常用。ChatGPT 给了一个测试本机支持哪种 backend 的程序:
import torch.distributed as dist
import os
def free_port():
    """Return a TCP port number that the OS reports as currently unused.

    Binds a socket to port 0 so the OS picks a free port, reads the port back
    and closes the socket. The port is very likely, but not guaranteed, to
    still be free when the caller uses it.
    """
    import socket  # local import: this snippet only imports os and dist at the top

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        return tcp.getsockname()[1]
def check_backend(backend):
    """Try to start a single-process group with *backend* and print the result.

    Sets MASTER_ADDR / MASTER_PORT in ``os.environ`` (read by the default
    env:// rendezvous) before initialising.

    Args:
        backend: backend name, e.g. "nccl", "gloo" or "mpi".

    Returns:
        True if the process group could be initialised, False otherwise.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(free_port())
    try:
        dist.init_process_group(backend, rank=0, world_size=1)
        print(f"{backend} backend initialized successfully.")
        return True
    except Exception as e:  # probe: any failure means "not usable here"
        print(f"Failed to initialize {backend} backend: {e}")
        return False
    finally:
        # Always tear the group down; a leftover default group would make the
        # next check fail because the default group can only be created once.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    # Probe every built-in backend in turn; check_backend prints each outcome.
    for name in ("nccl", "gloo", "mpi"):
        check_backend(name)
当有多卡可用时,可使用PyTorch DDP进行多卡训练。本文记录了使用DDP训练的代码模板,要点包括兼容多卡和单卡训练、DDP的起始与结束、打包模型、用特定进程保存和加载模型及验证等,还提及了测试本机支持的DDP backend的程序。
707

被折叠的 条评论
为什么被折叠?



