pytorch加载分布式方式训练的模型

当使用PyTorch的load_state_dict加载分布式训练的模型时,若`strict=False`模型能加载但输出可能不正确。若`strict=True`,需移除'模块.'前缀以匹配权重。解决方案是创建新字典,将旧键转换。在加载后,模型预测结果与原始模型不符,可能原因包括未调用`model.eval()`、模型结构中存在dropout等。参考链接提供了详细排查步骤。

出现的问题

  1. 使用load_state_dict函数,如果strict=False,模型可以加载,但是模型的输出不对
  2. 如果strict=True,但是没有新建字典,模型无法加载

问题解决

  • 保存的模型使用分布式方式训练,如果加载后模型不用分布式,则需要修改模型的key.
  • 分布式模型的权值名字前面有module.,非分布式不包含。
    比如分布式模型的权值名称名称为module.blocks.0.norm1.weight ,非分布式模型权值名称名称blocks.0.norm1.weight
  • 需要重新建一个字典,用去掉module.的名字座位key,值不变

代码示例

chk_filename = os.path.join(args.checkpoint, args.resume if args.resume 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值