个人使用pytorch的时候需要用到多GPU运行,简要说明一下应用情景:
- 单
GPU不够用,你需要将模型存储在多个GPU上; - 当模型初始化后运行在多个GPU上,你要加载dict模型参数文件。
第一个情景,我们使用 nn.DataParallel 来解决,直接上例子:
# 使用nvidia-smi查看可用的设备
CUDA_DEVICE_1 = 0
CUDA_DEVICE_2 = 1
# 模型初始化
net = model()
# 模型进入DataParallel,device_ids指明设备号列表
net = nn.DataParallel(net, device_ids=

本文介绍了在PyTorch中如何处理多GPU模型的存储和加载参数文件时遇到的问题。在需要扩展模型到多个GPU时,可以使用DataParallel。加载state_dict时会出现Error(s) in loading state_dict,原因在于key前需加上module。解决方案是通过修改加载参数的代码,确保keys匹配。
最低0.47元/天 解锁文章
1464

被折叠的 条评论
为什么被折叠?



