pytorch学习007- -预训练中的权重加载(完全导入,部分导入)

本文详细介绍了PyTorch中不同情况下模型权重的导入方法,包括模型结构完全对应、部分对应及集合关系上的对应等情况,并提供了具体示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

更新

2022.04.12更新
导入权重的用法相当普遍,但是可以导入吗?导入有什么影响?
首先一定是可以导入的,但是导入之后是否有效果?那应该分以下情况讨论。

  • 网络模型完全对应:这种情况可以导入,而且微调效果更好
  • 网络模型不完全对应(小心这种情况)
    • 只是输出层有部分变化,可以导入
    • 中间层有变化,不建议导入

问题

  1. 预训练后的权重如何导入另一个网络模型?
  2. 预训练对应的网络模型A与未训练的网络模型结构B不对应?
    2.1 两个网络模型A和B只有部分对应
    2.2 集合关系上A属于B
    2.3 集合关系上B属于A

方案

PyTorch文档

  • torch.nn.modules.module.Module def load_state_dict(self,
    state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],
    strict: bool = …) -> None
  • 说明:将 state_dict 中的参数和缓冲区复制到此模块及其后代中。
    • 如果 strict 为 True,则 state_dict 的键必须与此模块的torch.nn.Module.state_dict 函数返回的键完全匹配
  • 参数
    state_dict – 包含参数和持久缓冲区的字典。
    strict – 是否严格强制:
    • attr:state_dict 中的键与该模块的 :meth:~torch.nn.Module.state_dict 函数返回的键匹配。 默认值:“真”
  • 返回值:
    • missing_keys 是包含缺失键的 str 列表
    • unexpected_keys 是包含意外键的 str 列表

模型对应,完全导入

# demo1 完全加载权重
model = NET1()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights)

模型不完全对应

此一种情况经常出现在要修改预训练网络模型中某些层时,可能增加若干层,可能减少若干层,或上述两种情况皆有。

只有部分对应

在这里插入图片描述
两个模型中有部分是对应的,此种情况建议使用PyTorch中的load_state_dict所提供的参数:strict
将strict设置为False,可以在两个模型不同的情况下,仅加载相同键值部分。(保证各层的名字相同)

# demo2
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

A属于B

在这里插入图片描述
此种情况常见于,在网上download别人的预训练模型后,需要根据自己的任务,添加若干个层,而其他层保持不变。

# demo3
*****待测试

B属于A

在这里插入图片描述
此种情况常见于从网上download别人的预训练模型后,因为某些限制,需要对模型进行精简,只删除若干个层,其他层保持不变。

# demo4
*****待测试
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值