if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
多GPU环境计算loss
解决方法:
将多卡得到的loss进行mean,求平均:
loss_avg.backward() -----> loss_avg.mean().backward()
多GPU环境下对模型的保存:
#万能的保存方法,如果你的预测函数不会依赖你的模型类定义。
if isinstance(model, torch.nn.DataParallel):
torch.save(model.module.state_dict(), config.save_path)
#依赖你的模型定义,这么保存只保存了你模型的参数,模型的结构没有保存,所以尽量用上面的保存方法。
if isinstance(model, torch.nn.DataParallel):
torch.save(model.state_dict(), config.save_path)
多GPU环境下对模型的加载去预测:
model = torch.load(‘path/to/model’)
if isinstance(model,torch.nn.DataParallel):
model = model.module
#下面就可以正常使用了
model.eval()