“load_state_dict self.class.name, “\n\t”.join(error_msgs))) RuntimeError: Error(s) in loading state_for ***
model = create_model(num_classes=num_classes).to(device)
model_weight_path = "./weights/best_model.pth"
checkpoint = torch.load(model_weight_path, map_location=device)
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)
- 多半是训练保存的模型与预测加载的模型不一致所导致(个人因权重名称输入错误引起)
- 将load_state_dict(state_dict) 改成 model.load_state_dict(state_dict, strict=False)
load_state_dict方法参数的官方说明 strict 参数默认是true,是否严格要求state_dict中的键与该模块的键返回的键匹配。