Pytorch预训练模型和修改——记录

本文介绍如何在PyTorch中加载预训练模型及自定义模型,详细讲解了模型参数的修改方法,特别是针对全连接层的调整,并探讨了如何通过冻结特定层来进行微调。

加载模型

一般从torchvision的models中加载常用模型,如alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并提供预训练模型,调用方便。

from torchvision import models

resnet = models.resnet50(pretrain=True)
print(resnet)  # 打印网络结构

读取预训练模型

另一种是读取自己预训练模型,而不是使用官方自带。

import torch
resnet18 = models.resnet18(pretrained=False) #pretrained参数默认是False,为了代码清晰,最好还是加上参数赋值.

resnet18.load_state_dict(torch.load(path_params.pkl))

load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。

当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :

pretrained_dict = {
   
   k: v for k, v in pretrained_dict.items() if k in model_dict} 

再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。

整体代码:

print resnet18 #打印的还是网络结构
 
# 注意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype
 
pre_dict = resnet18.state_dict() #按键值对将模型参数加载到pre_dict
 
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值