在 pytorch 多GPU训练下,存储 整个模型 ( 而不是model.state_dict() )后再调用模型可能会遇到下面的情况:
AttributeError: ‘DataParallel’ object has no attribute ‘xxxx’
解决的方法是:
model = torch.load('path/to/model')
if isinstance(model,torch.nn.DataParallel):
model = model.module
#下面就可以正常使用了
model.eval()
...
...
...

本文解决在PyTorch中使用多GPU训练并存储整个模型后,重新加载模型时出现的AttributeError问题。通过判断并转换DataParallel模型,确保模型能够正确加载并使用。
1万+

被折叠的 条评论
为什么被折叠?



