出现的问题
- 使用
load_state_dict函数,如果strict=False,模型可以加载,但是模型的输出不对 - 如果
strict=True,但是没有新建字典,模型无法加载
问题解决
- 保存的模型使用分布式方式训练,如果加载后模型不用分布式,则需要修改模型的key.
- 分布式模型的权值名字前面有
module.,非分布式不包含。
比如分布式模型的权值名称名称为module.blocks.0.norm1.weight,非分布式模型权值名称名称blocks.0.norm1.weight - 需要重新建一个字典,用去掉
module.的名字座位key,值不变
代码示例
chk_filename = os.path.join(args.checkpoint, args.resume if args.resume

当使用PyTorch的load_state_dict加载分布式训练的模型时,若`strict=False`模型能加载但输出可能不正确。若`strict=True`,需移除'模块.'前缀以匹配权重。解决方案是创建新字典,将旧键转换。在加载后,模型预测结果与原始模型不符,可能原因包括未调用`model.eval()`、模型结构中存在dropout等。参考链接提供了详细排查步骤。
最低0.47元/天 解锁文章
1873

被折叠的 条评论
为什么被折叠?



