[技术干货-功能调试] ms报错"RuntimeError: Net parameters weight shape xxx is different from parameter_dict's xxx.."怎么办
1. 系统环境:
Hardware Environment(Ascend/GPU/CPU): ALL
Software Environment:
MindSpore version (source or binary): 1.6.0 & Earlier versions
Python version (e.g., Python 3.7.5): 3.7.6
OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu
GCC/Compiler version (if compiled from source):
复制
2. python代码样例
net = Net()
param_dict = load_checkpoint("net.ckpt")
load_param_into_net(net, param_dict)
复制
3. 报错信息:
RuntimeError: Net parameters weight shape ((751, 512)) is different from parameter_dict's ((31789, 512))复制
4. 原因分析
报错信息为:网络中参数的shape和ckpt中的shape不一致,无法正确加载。
加载checkpoint的流程为,首先读取checkpoint文件中的权重信息,把它们转化成dict,key值为parameter 的name,value为Parameter。
然后去遍历网络中的所有参数,依次在ckpt_dict中去查找,当查找到了同名的parameter后,就会进行加载操作。
如果这时发现网络中参数的shape与ckpt中参数的shape不一致,框架就不知道按照哪个shape为准了,就会报错。
5. 解决方法
遇到此问题的场景大多是加载ckpt的网络与保存ckpt的网络结构不一致。 可能是网络类型不一样,可能是网络的超参不一样导致的网络结构不一致。 此时,需要确定保存和加载时的网络结构是否一致。 可以使用network.parameters_dict方法获取网络中的参数; 使用param_dict = load_checkpoint("net.ckpt")方法获取checkpoint中的参数字典,再对比二者的差异。
6. 参考文档
checkpoint 用法可以参考: 保存与加载