【Pytorch】网络结构与预训练模型的网络结构是否一致的情况下:手动加载预训练模型、网络结构与模型参数匹配

本文详细介绍了如何在Pytorch中加载预训练模型,特别是当模型结构与预训练模型不完全一致时如何进行匹配。首先定义网络结构,如ResNet50,并根据任务调整全连接层。接着,读入自定义预训练模型,通过修改层名以匹配网络结构。在加载预训练模型参数时,若网络结构不一致(如ResNet50与ResNet101),则需要智能地选择并载入共享的层权重。最后,提供了两种不同的加载预训练模型的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

网络结构一致

这里模型使用Pytorch提供的ResNet50作为backbone,预训练模型的backbone也为ResNet50。

网络定义及读入预训练模型

利用pytorch提供的网络定义模型的backbone,设置不读取pytorch官方提供的ImagNet预训练模型,然后根据自己的任务设置backbone最后一层的参数。

import torch
import torchvision.models
from torch import nn
from collections import OrderedDict


model = torchvision.models.resnet50(pretrained=False)    # Pytorch提供的网络结构,不加载官方预训练模型(ImageNet)
fc_features = model.fc.in_features         # 提取fc层中固定的参数
model.fc = nn.Linear(fc_features, 400)     # 修改为自己项目的类别数量(也即预训练模型的类别数)

读入自己任务需要的预训练模型(自己提前下载)。

# 读入自己需要的预训练模型
pthfile = 'tf_model_zoo/tsn2d_kinetics400_rgb_r50_seg3_f1s1-b702e12f.pth'  # ResNet50,Kinetics400
pretrained_model = torch.load(pthfile)

网络结构的匹配与模型参数载入

通过输出网络模型或者调试查看网络属性,观察到所定义的模型与预训练模型的网络层名不一致,比如第一层名字分别为:conv1.weightbackbone.conv1.weight,那么使预训练模型的网络层名与定义模型的网络层名一致只需要去掉前缀就行了。

其中,如果需要全连接层的参数,使其也保持一致就可以了,同时也有两种载入预训练模型参数的方式,代码如下:

# 更改预训练模型的层名,使其匹配pytorch的ResNet50网络层名
new_state_dict = OrderedDict()
for k, v in pretrained_model['state_dict'].items():
    name = k[9:]   # remove `backbone.`
    if name == 'fc_cls.weight':            # 全连接层的参数也匹配
        name = 'fc.weight'
        new_state_dict[name] = v
    if name == 'fc_cls.bias':
        name = 'fc.bias'
        new_state_dict[name] = v
    new_state_dict[name] = v

model_dict = model.state_dict()    # 查看ResNet50 backbone的初始参数
model.load_state_dict(new_state_dict, strict=True)    # 载入预训练模型参数,严格匹配key的名字和数量

# 第二种预训练模型载入方式
# model_dict.update(new_state_dict)  # 更新ResNet50 backbone的初始参数
# model.load_state_dict(model_dict)  # 载入更新后的参数
print(model.state_dict())

网络结构不一致

这里模型使用Pytorch提供的ResNet101作为backbone,而预训练模型的backbone为ResNet50。

import torch
import torchvision.models
from torch import nn
from collections import OrderedDict


model = torchvision.models.resnet101(pretrained=False)    # Pytorch提供的网络结构,不加载官方预训练模型(ImageNet)
fc_features = model.fc.in_features         # 提取fc层中固定的参数
model.fc = nn.Linear(fc_features, 51)     # 修改为自己项目的类别数量(也即预训练模型的类别数)

与之前同理,首先观察网络层名有什么差异,然后更改层名使其一致,这里由于网络结构不同,所以只将ResNet101与ResNet50中都有的层名用预训练权重参数赋值。代码如下:

# 更改预训练模型的层名,使其匹配models.py中定义的base_model网络层名
new_state_dict = OrderedDict()
for k, v in pretrained_model['state_dict'].items():
    name = 'module.base_model.' + k[9:]  # 更改conv1.weight等层名的前缀
    new_state_dict[name] = v

pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict = model.state_dict()
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

Reference

  1. https://blog.youkuaiyun.com/chanbo8205/article/details/89923453
  2. https://blog.youkuaiyun.com/Charles5101/article/details/101028435
  3. https://blog.youkuaiyun.com/qq_36758461/article/details/112852050
  4. https://blog.youkuaiyun.com/qq_39852676/article/details/105611375
  5. https://blog.youkuaiyun.com/OTime77/article/details/105938268/
  6. https://zhuanlan.zhihu.com/p/84797438
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值