Pytorch 多GPU训练
PyTorch数据并行:
nn.DataParallel 一主机多GPU
DistributedParallel 多主机多GPU
net = torch.nn.DataParallel(model)
默认所有存在的显卡都会被使用
如果我们机子中有很多显卡(例如我们有5张显卡),但我们只想使用0、1、2号显卡
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
...
原创
2020-05-31 10:51:09 ·
207 阅读 ·
0 评论