加载预训练权重

# 加载预训练模型的权重
pretrained_dict = torch.load(model_path)  
# 创建一个新的字典pretrained_dict,只包含预训练模型中存在且维度与当前模型匹配的权重
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
# 更新当前模型的权重字典,只更新那些在预训练模型中找到匹配的权重
model_dict.update(pretrained_dict)
# 将更新后的权重字典加载到模型中
model.load_state_dict(model_dict)

注:model_dict是当前模型的权重字典,预训练权重存储在model_path指定的文件中。

model_dict.update(pretrained_dict)

  • model_dict 是当前模型的权重字典,它是通过调用 model.state_dict() 获取的。
  • pretrained_dict 是预训练模型的权重字典,它是通过调用 torch.load(model_path) 获取的。

update 方法的作用是将 pretrained_dict 中的键值对更新到 model_dict 中,但只有当 pretrained_dict 中的键(即参数名)在 model_dict 中存在时才会进行更新。这意味着,如果预训练模型中的某个参数在当前模型中不存在,那么这个参数不会被添加到 model_dict 中。这是一种保守的更新方式,确保不会覆盖或丢失当前模型中特有的参数。

具体来说,update 方法会遍历 pretrained_dict 中的每个键值对,如果键(参数名)在 model_dict 中存在,并且对应的值(参数张量)大小相同,那么 model_dict 中的这个键会被 pretrained_dict 中的值覆盖。如果键不存在或大小不匹配,则不做任何操作。

model.load_state_dict(model_dict)

  • model 是当前的PyTorch模型实例。
  • model_dict 是更新后的权重字典,包含了当前模型的权重以及从预训练模型中匹配并更新的权重。

load_state_dict 方法的作用是将 model_dict 中的权重加载到模型 model 中。这个方法会根据 model_dict 中的键(参数名)将对应的值(参数张量)赋值给模型的参数。

       具体来说,load_state_dict 方法会检查 model_dict 中的每个键值对,确保键与模型中的参数名匹配,并且值(张量)的大小与模型中的参数大小一致。如果匹配成功,它会将 model_dict 中的张量值赋给模型的相应参数。如果 model_dict 中有模型中不存在的参数,或者参数大小不匹配,PyTorch会抛出错误。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值