pytorch学习笔记之加载预训练模型

本文介绍了PyTorch中torchvision.models包的使用方法,详细讲述了四种模型加载方式:直接加载预训练模型、修改某一层、加载部分预训练模型及加载自定义模型,并附上了示例代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

同名微信公众号/知乎专栏: AI算法札记,欢迎关注交流

pytorch自发布以来,由于其便捷性,赢得了越来越多人的喜爱。

Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

本文,我们讲述的是models,且只谈模型的加载。models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models

官方文档:https://pytorch.org/docs/master/torchvision/models.html

我将加载的方法简单总结为以下四种:

1.直接加载预训练模型

import torchvision.models as models

resnet50 = models.resnet50(pretrained=True)

这样就导入了resnet50的预训练模型了。

如果只需要网络结构,不需要用预训练模型的参数来初始化,那么就是:

model =torchvision.models.resnet50(pretrained=False)

或者把resnet复制到自己的目录下,新建个model文件夹

可以参考下面的猫狗大战入门算法入门

https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch

2.修改某一层

 以resnet为例,默认的是ImageNet的1000类,比如我们要做二分类,分类猫和狗

resnet.fc = nn.Linear(2048, 2) 

resnet 第一层卷积的卷积核是7,我们可能想改成5,那么可以通过以下方法修改:

#未经试验,修改需要有理论依据,计算featuremap维度使之匹配。
resnet.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False)

3.加载部分预训练模型

对于具体的任务,很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

#加载model,model是自己定义好的模型
resnet50 = models.resnet50(pretrained=True) 
model =Net(...) 

#读取参数 
pretrained_dict =resnet50.state_dict() 
model_dict = model.state_dict() 

#将pretrained_dict里不属于model_dict的键剔除掉 
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict} 

# 更新现有的model_dict 
model_dict.update(pretrained_dict) 

# 加载我们真正需要的state_dict 
model.load_state_dict(model_dict)  

4. 加载自己的模型

其实这个是保存和恢复模型,比如我们训练好的模型保存,然后加载用于测试。

方法一(推荐):

第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

使用这种方法,我们需要自己导入模型的结构信息。

(1)保存

torch.save(model.state_dict(), PATH)

#example
torch.save(resnet50.state_dict(),'ckp/model.pth')    

(2)恢复

model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

#example
resnet=resnet50(pretrained=True)
resnet.load_state_dict(torch.load('ckp/model.pth'))

方法二:

使用这种方法,将会保存模型的参数和结构信息。

(1)保存

torch.save (model, PATH)

(2)恢复

model = torch.load(PATH)

参考资料:

1. https://zhuanlan.zhihu.com/p/25980324

2. http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

### PyTorch 学习笔记概述 李毅编写的《PyTorch学习笔记》是一份详尽的学习指南,旨在帮助读者掌握深度学习框架PyTorch的核心概念和技术。这份笔记不仅涵盖了基础理论知识,还提供了大量实践案例和代码实现。 #### 主要内容结构 1. **环境搭建** 安装配置PyTorch运行所需的软件环境,包括Python版本的选择、CUDA支持以及Anaconda的使用方法[^2]。 2. **张量操作** 解释了如何创建、转换和处理多维数组(即张量),这是构建神经网络模型的基础构件之一[^3]. 3. **自动求导机制** 描述了Autograd模块的工作原理及其在反向传播算法中的应用,使用户能够轻松定义复杂的计算图并高效训练模型[^4]. 4. **优化器与损失函数** 探讨了几种常用的梯度下降变体(SGD, Adam等)及相应的损失衡量标准(MSE Loss, CrossEntropyLoss等),这些组件对于调整权重参数至关重要[^5]. 5. **数据加载处理** 展示了Dataset类和DataLoader类的功能特性,它们可以简化大规模图像分类任务的数据读取流程;同时也介绍了常见的图片增强技术来扩充样本集规模[^6]. 6. **卷积神经网络(CNN)** 结合具体实例深入剖析CNN架构设计思路,如LeNet,VGG,resnet系列,并给出完整的项目源码供参考学习[^7]. 7. **循环神经网络(RNN/LSTM/GRU)** 阐述时间序列测场景下RNN家族成员的特点优势,通过手写字符识别实验验证其有效性[^8]. 8. **迁移学习实战演练** 利用预训练好的大型模型作为特征提取器,在新领域内快速建立高性能的应用程序,减少重复劳动成本的同时提高了泛化能力[^9]. 9. **分布式训练入门指导** 当面对超大数据集时,单机难以满足需求,此时可借助于torch.distributed包来进行集群式的协同工作模式探索[^10]. ```python import torch from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor()]) train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True) for images, labels in train_loader: print(images.shape) break ```
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值