报错内容
Missing key(s) in state_dict: “init_conv.weight”, “init_conv.bias”, …… Unexpected key(s) in state_dict: “module.init_conv.weight”, “module.init_conv.bias”,……
一看这两个就差了一个module,为什么会读取不出来呢?
解决
在pytorch的加载模型中,model.load_state_dict(state_dict, strict=True)
属性state_dict复制参数到这个模块。
如果strict为True, state_dict的keys必须完全与这个模块的方法返回的keys相匹配。如果为False,就不需要保证匹配。
因此,调用load_state_dict源代码中只需要增加一个False:
model.load_state_dict(torch.load('./model/model_weights.pt'),False)