【深度学习】CV_基于CNN的图像分类模型_代码逐行注释解析


前言

本文为个人案例学习笔记,是深度学习CV领域的入门上手项目,通过blog梳理总结,形成整体建模思路、熟悉典型网络架构的搭建、掌握常用代码使用。


提示:以下是本篇文章正文内容,下面案例可供参考

一、任务描述和关键环节

花卉图片识别分类。根据训练集“样本-标签”的分类(本案例采用文件夹分类方法),训练能够识别并对花卉品种进行分类的网络模型。

(一)数据预处理

  • 数据增强
    torchvision中transforms模块自带功能。数据较少,为了有更好的效果,需数据增强操作,如翻转、平移、裁剪等,以增强样本多样性。
  • 数据预处理
    torchvision中transforms已封装,直接调用即可。
  • DataLoader
    DataLoader模块直接读取batch数据。

(二)网络模块设置

  • 加载预训练模型
    torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习。
  • 修改网络模型
    别人训练好的模型并不一定满足个人任务,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
  • 训练策略选择
    训练时可以全部从头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

(三)网络模型保存与测试

  • 保存
    模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
  • 读取
    读取模型进行实际测试

二、具体步骤

(一)任务分析与图像数据处理

1.导包

代码如下:

# ================导入所需模块================ #
import os               #导入标准库,利用其中API
import matplotlib.pyplot as plt
%matplotlib inline      #jupyter魔法命令,绘图直接嵌入在notebook行内
import numpy as np
import torch
from torch import nn    #导入torch中神经网络工具包
import torch.optim as optim    #pytorch中优化器模块,可以优化模型参数
import torchvision      #导入torchvision包
                        #pip install torchvision
from torchvision import transforms, models, datasets  
                        #tansforms包有内置数据增强策略,models封装好的神经网络模型,比如resnet,dataset数据目录结构
                        #https://pytorch.org/docs/stable/torchvision/index.html
import imageio          #python处理图像、视频的模块,读取、写入、转换格式等。
import time
import warnings
warnings.filterwarnings("ignore")    #忽略版本造成的告警
import random           #随机数模块,生成随机数。如:random.randint(0, 10)0-10之间生成1个随机数
import sys              #导入 Python 标准库中的 sys 模块,可以使用该模块获取版本信息、命令行参数等。如:print(sys.version)
import copy             #该模块用于使用复制副本功能,用其创建副本而不用引用原对象。
import json             #导入JSON数据格式的处理模块
from PIL import Image   #导入 Python Imaging Library(PIL)中的image 模块。PIL 是一个用于处理图像的 Python 库,image 模块包含了 PIL 中的图像处理功能
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"    #防止anaconda内核挂掉的指令,别问,玄学!

2.数据读取与预处理

2.1 数据读取
  • 方法1:通过文件夹做数据集分类,文件夹名为label标签
  • 方法2:用txt写成标注文件。本案例数据少,需数据增强
  • 本案例数据集保存在conda目录下,名称为flower_data,内含train和valid2个文件夹,作为训练和验证集
data_dir = './flower_data/'        #指定数据读取目录
train_dir = data_dir + '/train'    #训练集
valid_dir = data_dir + '/valid'    #验证集
2.2 数据预处理
(1)制作数据源
  • data_transforms中指定了所有图像预处理操作
  • ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字
  • 几种图像数据常用的数据增强操作,图像亮度、对比度、灰度等操作不常用,用在特殊场景下图像识别任务。
# 字典结构,2个key分别train和valid
data_transforms = {
   
    'train': 
        transforms.Compose([            #.compose指按顺序操作,以([])首尾括起来
        transforms.Resize([96, 96]),    #图像调整为相同大小,会丢失部分信息。一般分类任务多用正方形,size根据实际图像情况设定,不能丢失太多。
                                        #数据增强:如图像的翻转、旋转、平移等操作,由1张变化,得到多张图,数据更丰富。
                                        #一般操作有:旋转、裁剪、水平垂直翻转。
        transforms.RandomRotation(45),  #数据增强:随机旋转,-4545度之间随机选角度。
        transforms.CenterCrop(64),      #数据增强:从中心开始裁剪。理解:原始图像上随机选区域裁剪,增加数据多样性,残缺不全图像。
        transforms.RandomHorizontalFlip(p=0.5),    #数据增强:随机水平翻转 选择一个概率,执行该操作的概率
        transforms.RandomVerticalFlip(p=0.5),      #数据增强:随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),    #数据增强:参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),       #数据增强:概率转换成灰度图,3通道就是R=G=B。RGB转为RRR或BBB或BBB,一般不做。
        transforms.ToTensor(),  #数据转换为张量
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  #标准化操作:均值,标准差,可以用imagnet大数据集的均值和标准差,括号3个值对应RGB三通道
    ]),
    'valid': 
        #验证集不再需要做数据增强,用实际图像做验证即可。
        transforms.Compose([
        transforms.Resize([64, 64]),  #保证验证和训练图像大小一样,训练中中心裁剪后,输出64*64的
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  #须用于训练集中一样的均值和标准差,不能用先验知识做。
    ]),
}
(2)将预处理的数据指定好
batch_size = 128  #batch较大,因为输入图像小,64*64的,所以可以适当大batch

#本案例中,数据按分类存在文件夹中,文件夹名可作为标签,用于读取数据,不再通过datasets dataloader,通过文件夹形式读取数据,下文命令datasets.ImageFolder

#定义datasets: 2个文件夹train和valid,作为key值.指定好数据集和预处理操作,方法是ImageFolder
image_datasets = {
   x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
#定义dataloaders: 方法orch.utils...(参数datasets,batch,shuffle)
dataloaders = {
   x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {
   x: len(image_datasets[x]) for x in ['train', 'valid']}    #可有可无,用来算准确率
class_names = image_datasets['train'].classes    #预测顺序索引位置对应类的名称

#Tips:数据集flower_data放在notebook同一级文件夹内。
  • 打印上述image_datasets、class_names、dataloaders、dataset_sizes查看。
  • image_datasets,展示datasets数据集相关参数和相应数据增强操作。
  • class_names:最终预测顺序,本案例102个文件夹,102种类的花,最后图片分类任务会有102个特征值,而class展示出的特征不是1234顺序,而是1 10 100 102 2 20,因此索引位置不是按照1234。
  • 利用jupyter分块执行的特点,打印类别名查看,按照首位数字索引排列,如下:
class_names
---
'1',
 '10',
 '100',
 '101',
 '102',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '2',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '3',
 '30',
 '31',
 ...]
  • 打印查看图像数据集的相关参数以及数据增强操作有哪些?
image_datasets
---
{
   'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=None)
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(64, 64))    #中心裁剪
                RandomHorizontalFlip(p=0.5)  #随机水平
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=None)
                ToTensor()  
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
  • 查看数据集
dataloaders
---
{
   'train': <torch.utils.data.dataloader.DataLoader at 0x2095ee35100>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2095ee354f0>}
  • 查看数据集大小
dataset_sizes
---
{
   'train': 6552, 'valid': 818}
2.3读取标签对应的实际名字
with open('cat_to_name.json', 'r') as f:  #读取Json文件的操作。文件内是字典结构,对应文件夹标签即数字值得真实花名。参数r以只读方式读取文件。
    cat_to_name = json.load(f)            #理解:加载f(读取json文件)这一操作。
  • 打印抓取的名字
cat_to_name
---
{
   '21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
...}

(二)模型设置

1.选用经典网络预训练模型

  • 加载models中提供的模型,并且直接用训练好的权重当做初始化参数
  • 第一次执行需要下载,可能会比较慢,我会提供给大家一份下载好的,可以直接放到相应路径
  • 迁移学习
    从0学效果并不好,在别人模型基础上微调。如:用resnet模型和参数,作为我们的初始化。即在经典预训练模型基础上做微调。
  • 迁移学习策略
    具体区分数据集大小,大、中、小,数据集小时,少量微调,经典模型中冻住大部分参数;中等时,前面冻住一部分权重参数;数据量大时,冻住小部分。
model_name = 'resnet'  #可选网络有 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],可尝试多少种经典网络结构试验效果。
                       #用别人训练好的特征来做。用别人网络结构方法:①torchvision调包来用;②复制粘贴典型网络结构、配置文件。

feature_extract = True #都用人家特征,咱先不更新。True时,冻住所有特征,只保留最后输出层。
  • 选择运行设备
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...') 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
---
CUDA is available!  Training on GPU ...

2.结合实际更新模型参数

  • 有时候用人家模型,就一直用了,更不更新可以自己定
  • 前面导入了torchvision模块中的models包,封装了典型网络模型,此处直接调用resnet,打印可查看网络内部结构
  • 18层Resnet网络结构10层,其中4层隐层layer中各包含2层Basic Block,最后2层平均池化、全连接层,但模型中输出分类是1000分类,需改动参数符合本案例。
model_ft = models.resnet18()  #torchvision封装了典型的包(模型),选用18层resnet做,18层的能快点,条件好点的也可以选152
model_ft
---
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值