问题描述
pytorch中如果使用DataParallel
进行多卡训练,那么保存的模型key值前面会多处modules.
;若测试的时候使用单GPU,模型载入就会出现问题。
解决办法
model_path = '/home/work/model/checkpoint/mobilenet-v2_100.pth'
device = torch.device("cuda")
net = mobilenet_v2().to(device)
checkpoint = torch.load(model_path)
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
name = k[7:] # remove "module."
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
简而言之,就是重新创建一个OrderedDict,然后将它载入模型就行了。