在加载别人已经训练好的网络参数做测试时遇到以下问题
图(1)
使用net.load_state_dict() 加载参数,遇到Missing key(s)的报错,必然是定义的网络结构与加载的网络结构不同。
如上图所示,待加载的网络参数对比定义的网络参数Missing和Unexpected一 一可以对应。从中得出,待加载的网络和定义的网络是可以对应上的,只是网络命名不一样。对于这种问题有两种方法解决,要么修改待加载的网络名,要么修改定义是的网络名。本文采用第二种。
首先打印一下定义的网络结构
print(net)

图(2)
再打印一下加载的网络参数关键字
state_dict=torch.load(load_path)
print(state_dict.keys())

图(3)
在保存网络参数时是不会保存归一化层和激活函数层的,图(3)关键字中 model.10.conv_block.1.weight 中model代表整个模型,10代表第十层,conv_block代表conv_block模块,1代表该模块的第一层,weight代表参数,与图(2)中的网络一 一对应。
再来看图(1)中的错误,图(2)中的模型第10层中的conv_block模块中第6层是卷积层,而在图(3)的网络参数中第五层是卷积层参数。所以只要将卷积层的参数、偏差一 一对应即可。
在这里我是将所有的dropout层都去掉,因为测试时dropout本身就用不到。当然也可以调整一下位置。
深度学习 加载网络参数遇到 Missing key(s)
最新推荐文章于 2023-11-29 09:52:38 发布
在PyTorch中加载预训练模型参数时遇到MissingKey错误,通常是因为网络结构定义与加载的参数不匹配。文章提到,可以通过比较模型结构和参数键来对应不同的命名。作者选择修改定义的网络结构以适应加载的参数,特别地,由于保存的参数不包含归一化层和激活函数层,以及dropout层在测试时不需要,因此作者去掉了dropout层以使参数对应。
3778

被折叠的 条评论
为什么被折叠?



