1、模型参数加载:如和确定 “model.load_state_dict(torch.load("weight_path"),strict=False)” 加载到了参数?(strict 表示是否需要参数结构和model完全一样)
load_state_dict 会返回两个参数 一个是miss,一个是unexcepted,第一个参数表示:model中没有得到参数加载的部分(例如 model有100个模块,参数中只有20个模块的参数,那么miss就是剩下80个没有得到参数加载的模块),第二个参数表示哪些参数是model不能加载的(参数中有个模块叫做 A,这个模块model中没有,那么model就不能加载这个模块的参数,这个A就是unexpected)
所以如果想要确定 model.load_state_dict 是否真的给模型赋予了有效的参数 可以检查 miss中的模块和 model中的模块是不是一样,一样就说明没有加载参数,不一样就说明加载到了参数。具体来说就是打印model中的所有模块 print( model.state_dict() )。 打印miss ,print( miss ),然后对比输出是否一样,如果一样说明 model 未加载到任何参数,如果不一样,说明加载到了部分参数。
2、如何保留模型中部分模块的参数:
mm_state_dict = OrderedDict()
state_dict = unet.state_dict()
for key in state_dict:
if "motion_module" in key:
mm_state_dict[key] = state_dict[key]
torch.save(mm_state_dict, “xxx.pth”)
3、参数预训练模型和目标模型都有但是参数对不上(例如形状不同)
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
src_state_dict = state_dict['net']
target_state_dict = model.state_dict()
skip_keys = []
for k in src_state_dict.keys():
if k not in target_state_dict:
continue
if src_state_dict[k].size() != target_state_dict[k].size():
skip_keys.append(k)
for k in skip_keys:
del src_state_dict[k]
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)
文章讲述了如何在Python深度学习中正确地加载模型参数,包括`load_state_dict`函数的使用、处理miss和unexpected的情况,以及在模型参数形状不同时的调整方法。还提到了如何保留部分模块参数并保存。
1203





