Pytorch从入门到放弃(2)——迁移学习(基于ResNet18的蜜蜂和蚂蚁分类)

本文介绍了如何利用预训练的ResNet18模型进行迁移学习,对蜜蜂和蚂蚁图像进行分类。通过调整模型的最后一层适应目标数据集,并采用两种微调策略:仅训练最后一层并冻结其余层,或所有层参与训练但最后一层使用更高学习率。实验结果显示了这两种方法的效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        实践中,受限于数据集规模的约束,我们很少从头开始端到端的训练一个神经网络。通常情况下,我们会选择在ImageNet数据集上预训练好的网络模型上进行适当的修改,使其适用于目标数据集。

        首先,修改网络模型的最后一个全连接层,使其适应于目标数据集,使用预训练的网络权重来初始化网络模型的权重,用自己的图像数据来微调训练网络。微调网络主要有以下两种做法:

1.只训练最后一个全连接层,冻结除最后一个全连接层外的所有层的权重。

2.所有网络层都参与训练,不过最后一个全连接层在训练时使用更大的学习率,通常最后一个全连接层的学习率是前面层学习率的10倍。

        下面基于迁移学习实现一个ResNet18来对蜜蜂和蚂蚁分类,点击这里下载数据集。蚂蚁和蜜蜂大约均有120幅训练图像。每个类别有75幅验证图像。

1.只训练最后一个全连接层

from __future__ import print_function, division

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import time
import os
import copy

# 是否使用gpu运算
use_gpu = torch.cuda.is_available()
# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        # 随机在图像上裁剪出224*224大小的图像
        transforms.RandomResizedCrop(224),
        # 将图像随机翻转
        transforms.RandomHorizontalFlip(),
        # 将图像数据,转换为网络训练所需的tensor向量
        transforms.ToTensor(),
        # 图像归一化处理
        # 个人理解,前面是3个通道的均值,后面是3个通道的方差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 读取数据
# 这种数据读取方法,需要有train和val两个文件夹,
# 每个文件夹下一类图像存在一个文件夹下
data_dir = '../data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}

# 读取数据集大小
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# 数据类别
class_names = image_datasets['train'].classes

# 训练与验证网络(所有层都参加训练)
def train_model(model, criterion, optimizer, scheduler, num
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值