PyTorch笔记9----------Cifar10图像分类

1.图像分类网络模型框架解读

  • 分类网络的基本结构
    • 数据加载模块:对训练数据加载
    • 数据重组:组合成网络需要的形式,例如预处理、增强、各种网络处理、loss函数计算
    • 优化器
  • 数据加载模块
    • 使用公开数据集:torchvision.datasets
    • 使用自定义数据集:torch.utils.data下的Dataset、DataLoader
  • 数据增强模块
    • 使用torchvision.transforms

2.Cifar10数据读取

Cifar10数据集下载链接:https://pan.baidu.com/s/1Dc6eQ54CCLFdCA2ORuFChg 提取码: 5279

下在好的数据集解压后的文件

创建两个文件夹dataTrain和dataTest用于存储数据集的图片

将数据集中的训练集图片和测试集图片存入自建的文件夹中,代码如下:

import os
import cv2
import numpy as np
import glob
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
label_name = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]
# train_list = glob.glob('cifar-10-batches-py/data_batch_*') #下载训练集图片时使用此行
train_list = glob.glob('cifar-10-batches-py/test_batch*')
# save_path = 'cifar-10-batches-py/dataTrain' #下载训练集图片时使用此行
save_path = 'cifar-10-batches-py/dataTest'
for l in train_list:
    l_dict = unpickle(l)
    for im_idx,im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        im_label_name = label_name[im_label]
        im_data = np.reshape(im_data,[3,32,32])
        im_data = np.transpose(im_data,(1,2,0))
        if not os.path.exists("{}/{}".format(save_path,im_label_name)):
            os.mkdir("{}/{}".format(save_path,im_label_name))
        cv2.imwrite("{}/{}/{}".format(save_path,im_label_name,im_name.decode("utf-8")),im_data)

3.自定数据集加载

from torchvision import transforms
from torch.utils.data import  DataLoader, Dataset
import os
from PIL  import Image
import numpy as np
import glob 
label_name = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]
label_dict = {}
for idx,name in enumerate(label_name):
    label_dict[name] = idx
def default_loader(path):
    return Image.open(path).convert('RGB')
#数据增强方法
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28,28)), #随机裁剪
    transforms.RandomHorizontalFlip(), #随机水平翻转
    transforms.RandomVerticalFlip(), #随机垂直翻转
    transforms.RandomRotation(90), #随机旋转
    transforms.RandomGrayscale(0.1), #随机灰度化
    transforms.ColorJitter(0.3,0.3,0.3,0.3), #随机颜色调整
    transforms.ToTensor() #转换为张量
])
class MyDataset(Dataset):
    def __init__(self,im_list,transform=None,loader=default_loader): 
        super(MyDataset,self).__init__()
        imgs = []
        for im_item in im_list:
            im_label_name = im_item.split("/")[-2]
            imgs.append([im_item,label_dict[im_label_name]])
        self.imgs = imgs
        self.transform = transform
        self.loader = loader
    def __getitem__(self,index):
        im_path,im_label = self.img[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data,im_label
    def __len__(self):
        return len(self.imgs)
im_train_list = glob.glob("cifar-10-batches-py/dataTrain/*/*.png") #获取训练集图片路径列表
im_test_list = glob.glob("cifar-10-batches-py/dataTest/*/*.png") #获取测试集图片路径列表
train_dataset = MyDataset(im_train_list,transform=train_transform) #创建训练集数据集
test_dataset = MyDataset(im_test_list,transform=transforms.ToTensor()) #创建测试集数据集
train_data_loader = DataLoader(dataset=train_dataset,batch_size=6,shuffle=True,num_workers=4)#创建训练集数据加载器
test_data_loader = DataLoader(dataset=train_dataset,batch_size=6,shuffle=False,num_workers=4)#创建测试集数据加载器
print("num_of_train:",len(train_dataset))
print("num_of_test:",len(test_dataset))

代码运行结果:

num_of_train: 50000
num_of_test: 10000

 4.VGG网络搭建

  • 模型网络搭建
import torch
import torch.nn as nn
import torch.nn.functional as F
#定义vgg网络
class VGGbase(nn.Module):
    #定义vgg网络的初始化函数
    def __init__(self):
        super(VGGbase,self).__init__() #调用父类的初始化函数
        #定义第一个卷积层,图像大小:28*28
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2,stride=2) #定义最大池化层
        #定义第二个卷积层,图像大小:14*14
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2,stride=2) #定义最大池化层
        #定义第三个卷积层,图像大小:7*7
        self.conv3_1 = nn.Sequential(
            nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值