1. 为什么会有 'module.'
前缀
当模型在 分布式训练 或 多 GPU 模型 中使用 PyTorch 的 torch.nn.DataParallel
或 torch.nn.parallel.DistributedDataParallel
时,模型的权重会被自动加上 'module.'
前缀,保存时无法直接去掉
# 使用 DataParallel 包装模型
model = torch.nn.DataParallel(model)
# 保存模型权重
torch.save(model.state_dict(), "model.pth")
#键名类似如下
#module.block0.conv1.weight
#module.block0.conv1.bias
2. 为什么要去掉 'module.'
前缀?
单 GPU 或 CPU 环境 ,键名 不包含 module.
for key in list(ckpt.keys()):
if 'module.' in key:
ckpt[key.replace('module.', '')] = ckpt[key] # 去掉 'module.' 前缀
del ckpt[key] # 删除原始带有 'module.' 的键