PyTorch官方教程之6:迁移学习

本文详细介绍了如何在PyTorch中应用迁移学习,包括加载数据、可视化、训练模型、微调和特征提取两个场景。通过预训练的ConvNet作为起点,对蚂蚁和蜜蜂分类任务进行训练,展示如何提升模型的泛化能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

通常的做法是在一个很大的数据集上进行预训练得到卷积网络ConvNet,然后将这个ConvNet的参数作为目标任务的初始化参数或者固定这些参数。

下面是利用PyTorch进行迁移学习步骤,要解决的问题是训练一个模型来对蚂蚁和蜜蜂进行分类。

from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import warnings

warnings.filterwarnings("ignore")

# matplotlib交互模式,生成动态图
plt.ion()

加载数据

在该数据集中,ants和bees各有约120张训练图片。每个类有75张验证图片。该数据集是ImageNet的一个非常小的子集。点击下载数据。
从零开始在如此小的数据集上进行训练通常是很难泛化的。由于我们使用迁移学习,模型的泛化能力会相当好。

# 数据增强
# 训练集数据扩充和归一化
# 在验证集上仅需要归一化
data_transforms = {
   
   
    'train': transforms.Compose([  # 所有的转换使用Compose链接在一起
        transforms.RandomResizedCrop(224),  # 随机裁剪224×224图像
        transforms.RandomHorizontalFlip(),  # 随机水平翻转图像,默认概率为50%。
        transforms.ToTensor(),  # ToTensor将数值范围为0-255的PIL Image转换为一个浮点张量,并通过将其除以255将其归一化为0-1。
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
        # 归一化使用一个3通道张量,每个通道被归一化为T = (T - mean)/(标准差)
    ]),
# 注意:对于验证和测试数据,不执行RandomResizedCrop、RandomRotation和RandomHorizontalFlip转换。
    'val': transforms.Compose([
        transforms.Resize(256),  # 按照比例把图像最小的一个边长放缩到256,另一边按照相同比例放缩
        transforms.CenterCrop(224),  # 从中心裁剪224×224图像
        transforms.ToTensor(),  # 转化为张量
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 归一化
    ]),
}

# 数据加载
data_dir = 'data/hymenoptera_data'
# Datasets子类的实例
# datasets.ImageFolder 返回的dataset其属性self.classes:用一个 list 保存类别名称
image_datasets = {
   
   x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

# 工具函数torch.utils.data.DataLoader,让数据变成mini-batch,且在准备mini-batch的时候可以多线程并行处理
dataloaders = {
   
   x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=4,
                                              shuffle=True,  # 在每个 epoch 开始时,是否对数据进行打乱
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值