Pytorch预训练模型和修改——记录

本文介绍如何在PyTorch中加载预训练模型及自定义模型,详细讲解了模型参数的修改方法,特别是针对全连接层的调整,并探讨了如何通过冻结特定层来进行微调。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

加载模型

一般从torchvision的models中加载常用模型,如alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并提供预训练模型,调用方便。

from torchvision import models

resnet = models.resnet50(pretrain=True)
print(resnet)  # 打印网络结构

读取预训练模型

另一种是读取自己预训练模型,而不是使用官方自带。

import torch
resnet18 = models.resnet18(pretrained=False) #pretrained参数默认是False,为了代码清晰,最好还是加上参数赋值.

resnet18.load_state_dict(torch.load(path_params.pkl))

load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。

当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。

整体代码:

print resnet18 #打印的还是网络结构
 
# 注意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype
 
pre_dict = resnet18.state_dict() #按键值对将模型参数加载到pre_dict
 
print for k, v in pre_dict.items(): # 打印模型参数
 
for k, v in pre_dict.items():
  print k  #打印模型每层命名
 
# model是自己定义好的新网络模型,将pretrained_dict和model_dict中命名一致的层加入
 
# pretrained_dict(包括参数)。
 
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

模型参数修改

对模型特定层进行修改,一般直接调用对应层名并赋予新的层结构,常用是修改全连接层输出类别或特征数。

import torchvision.models as models 
#调用模型 
model = models.resnet50(pretrained=True) 
#提取fc层中固定的参数 
fc_features = model.fc.in_features 
#修改类别为9 
model.fc = nn.Linear(fc_features, 9)

训练特定层(冻结层)

对于自己定义的网络结构,需要选择特定层进行冻结时,往往需要用到 requires_grad函数来定义是否计算梯度。

count = 0
para_optim = []
for k in model.children():
  count += 1
  if count > 6:   # 6 should be changed properly
    for param in k.parameters():
      para_optim.append(param)
  else:
    for param in k.parameters():
      param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)#只对特定的层的参数进行优化更新,即选择特定的层进行finetune。

此代码实现了PyTorch中使用预训练的模型初始化网络的一部分参数,主要是:

  • 设置选择条件
  • 使用children函数得到子模型名,使用parameters得到参数
  • 使用requires_grad确定不计算梯度,冻结层

PyTorch的Module.modules()和Module.children()

在PyTorch中,所有的neural network module都是class torch.nn.Module的子类,在Modules中可以包含其它的Modules,以一种树状结构进行嵌套。当需要返回神经网络中的各个模块时,Module.modules()方法返回网络中所有模块的一个iterator,而Module.children()方法返回所有直接子模块的一个iterator。具体而言:

list ( nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules() )
Out[9]:
[Sequential (
(0): Linear (10 -> 20)
(1): ReLU ()
), Linear (10 -> 20), ReLU ()]

In [10]: list( nn.Sequential(nn.Linear(10, 20), nn.ReLU()) .children() )
Out[10]: [Linear (10 -> 20), ReLU ()]

另一个例子:

下载好的模型,可以用下面这段代码看一下模型参数,并且改一下模型。在vgg19.pth同级目录建立一个test.py。

import torch
import torch.nn as nn
import torchvision.models as models

vgg16 = models.vgg16(pretrained=False)

#打印出预训练模型的参数
vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
print('vgg16:\n', vgg16) 

modified_features = nn.Sequential(*list(vgg16.features.children())[:-1])
# to relu5_3
print('modified_features:\n', modified_features )#打印修改后的模型参数

修改好之后features就可以拿去做Faster-RCNN提取特征用了

参考资料

### MOT20 数据集预训练模型下载与使用 对于希望利用MOT20数据集进行研究或开发工作的用户来说,获取合适的预训练模型至关重要。通常情况下,预训练模型可以从多个渠道获得。 #### 获取预训练模型 一种常见的方法是从官方或其他可信的研究机构发布的资源库中直接下载。例如,在一些公开的GitHub项目页面上,开发者会分享其基于特定版本YOLO或者其他框架训练好的权重文件[^2]。这些链接往往位于项目的README文档内,提供了详细的安装指导以及如何加载预训练参数的具体步骤。 另一种方式则是访问像Model Zoo这样的平台,这里汇集了大量的计算机视觉领域内的高质量预训练网络结构及其对应的权值矩阵。以DeepSort为例,虽然最初非专门为MOT20设计,但是由于该算法具有良好的泛化能力,因此也可以尝试将其应用于新的场景之中,借助迁移学习技术进一步优化性能表现[^1]。 #### 使用教程概览 当获得了适用于MOT20数据集的预训练模型之后,接下来就是配置环境运行测试程序的过程: - **准备依赖项**:确保本地环境中已经正确安装了Python解释器及相关必要的第三方库(如PyTorch、NumPy等),这一步骤可以通过pip工具轻松完成。 - **设置工作区**:创建一个新的虚拟环境用于隔离不同项目之间的潜在冲突;克隆包含所需脚本配置文件的Git仓库至本地磁盘位置。 - **调整超参设定**:根据实际需求修改默认的学习率、批次大小以及其他影响最终效果的关键因素。值得注意的是,如果打算继续微调现有模型,则应特别关注初始化策略的选择——即是否采用随机数填充还是沿用之前保存下来的最优解作为起点。 - **执行推理流程**:编写简单的命令行指令启动预测服务,指定输入源路径指向待处理的目标序列图像集合;输出结果将会被自动存储于预先定义好的目的地文件夹之下,形式可以是JSON格式或者是类似于`track.py`所生成的那种纯文本记录表单[^3]。 ```bash # 安装依赖包 pip install torch torchvision numpy opencv-python-headless # 克隆仓库 git clone https://github.com/example/repo.git # 进入项目目录 cd repo # 创建激活虚拟环境 (可选) python -m venv env source ./env/bin/activate # Linux/macOS .\env\Scripts\activate # Windows # 修改配置文件中的参数... # 启动推理进程 python track.py --input /path/to/input_videos --output /desired/output_directory ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值