pytorch 加载使用部分预训练模型(pretrained model)

找了一些资料,对我这种菜鸟并不友好,把自己摸索的相对详细的过程记录一下。

pytorch 加载全部模型比较简单,直接使用如下代码:

net.load_state_dict(torch.load(pth_path))

现在只想使用 上述net(假设叫net-a)的中间的一部分模型,步骤如下:
1. 根据net-a的网络模型代码(如下),新建一个副本网络模型net-b代码。

    class Net-b(nn.Module):
        def __init__(self, **kwargs):
            super(Net-b, self).__init__()
            ... ... #保留net-a中 我们需要使用的中间卷积层等
        def forward():
        	... ... 

2. 根据Net-b网络模型文件,建立net-b

net-b = Net-b()  #别忘记传递必要的参数
net-b_dict = net-b.state_dict()

3. 将训练好的net-a权重 ‘复制’ 到net-b上

state_dict = torch.load(net-a_ckpt_path)	#加载预先训练好net-a的.pth文件
new_state_dict = OrderedDict()		#不是必要的【from collections import OrderedDict】 

new_state_dict = {k:v for k,v in state_dict.items() if k in net-b_dict}	#删除net-b不需要的键
net-b_dict.update(new_state_dict)	#更新参数
net-b.load_state_dict(net-b_dict)	#加载参数

检查一下是否复制成功:

for name, para in net-a.named_parameters():
    print(name, torch.max(para))
for name, para in net-b.named_parameters():
    print(name, torch.max(para))

打印部分结果如下,可见参数的名称和其中最大值是相同的,故net-b可以正常使用

#net-a 的参数和权重
b3.c3.body.0.bias tensor(0.0591, grad_fn=<MaxBackward1>)
c1.body.0.weight tensor(0.3648, grad_fn=<MaxBackward1>)
c1.body.0.bias tensor(0.0897, grad_fn=<MaxBackward1>)
#net-b 的参数和权重
b3.c3.body.0.bias tensor(0.0591, grad_fn=<MaxBackward1>)
c1.body.0.weight tensor(0.3648, grad_fn=<MaxBackward1>)
c1.body.0.bias tensor(0.0897, grad_fn=<MaxBackward1>)

除了这种方法外,还可以建立网络建构时,将中间使用块单独命名(net2),使用时直接调用Net.net2

### 如何在 PyTorch加载预训练模型 为了在 PyTorch加载预训练模型,通常有两种主要方式:一种是从官方提供的预训练权重文件中直接加载;另一种则是通过自定义路径加载本地保存的模型。下面分别介绍这两种方法并提供相应的代码实例。 #### 方法一:从官方资源加载预训练模型 许多流行的神经网络架构已经在 torchvision.models 或其他库中实现了,并附带了 ImageNet 数据集上的预训练参数。只需指定 `pretrained=True` 即可轻松获取这些预训练好的模型。 ```python import torch from torchvision import models model = models.resnet18(pretrained=True) # 加载 ResNet-18 预训练模型 ``` 此段代码展示了如何利用 torchvision 库中的 resnet18 函数来创建一个带有 ImageNet 上预训练权重的 ResNet-18 模型[^1]。 #### 方法二:从本地磁盘加载已保存的模型 当需要恢复之前训练过的特定版本的模型时,则可以采用这种方式。这涉及到先序列化(即保存)整个模型的状态字典到硬盘上,在后续运行期间再反序列化回来。 ```python # 假设 'checkpoint.pth' 是先前保存的一个检查点文件名 checkpoint_path = "path/to/checkpoint.pth" device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # 初始化相同结构的新模型对象 model = TheModelClass(*args, **kwargs) # 将模型移动至 GPU (如果可用的话) model.to(device) # 使用 map_location 参数确保即使是在不同设备间也能正确读取数据 state_dict = torch.load(checkpoint_path, map_location=device)[^2] # 载入状态字典更新当前模型参数 model.load_state_dict(state_dict) # 设置为评估模式 model.eval() ``` 上述代码片段说明了怎样使用 `torch.load()` 来处理可能存在于 CPU 和 GPU 不同环境之间的兼容性问题,同时也强调了设置好目标计算平台的重要性。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值