1. Overview of the image classification network framework
- Basic structure of a classification network
  - Data loading module: loads the training data
  - Data assembly: reshapes the data into the form the network expects, e.g. preprocessing, augmentation, the network's forward computation, and the loss calculation
  - Optimizer
- Data loading module
  - Public datasets: torchvision.datasets (a minimal sketch follows this list)
  - Custom datasets: Dataset and DataLoader from torch.utils.data
- Data augmentation module
  - torchvision.transforms
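For the public-dataset route listed above, loading CIFAR-10 through torchvision.datasets and wrapping it in a DataLoader might look like the following minimal sketch (the download root ./data and the batch size are assumptions for illustration):

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# ToTensor converts HWC uint8 images to CHW float tensors in [0, 1]
transform = transforms.ToTensor()

# download=True fetches CIFAR-10 into ./data on first use (root path is an assumption)
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=6, shuffle=True, num_workers=4)

print("num_of_train:", len(train_set))  # 50000

The rest of this article takes the custom-dataset route instead, unpacking the raw CIFAR-10 batch files into per-class image folders and loading them with a Dataset subclass.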
2.Cifar10数据读取
Cifar10数据集下载链接:https://pan.baidu.com/s/1Dc6eQ54CCLFdCA2ORuFChg 提取码: 5279
下在好的数据集解压后的文件

创建两个文件夹dataTrain和dataTest用于存储数据集的图片
将数据集中的训练集图片和测试集图片存入自建的文件夹中,代码如下:
import os
import cv2
import numpy as np
import glob

def unpickle(file):
    # each CIFAR-10 batch file is a pickled dict with keys b'data', b'labels', b'filenames'
    import pickle
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict

label_name = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]

# train_list = glob.glob('cifar-10-batches-py/data_batch_*')  # use this line for the training-set images
train_list = glob.glob('cifar-10-batches-py/test_batch*')
# save_path = 'cifar-10-batches-py/dataTrain'  # use this line for the training-set images
save_path = 'cifar-10-batches-py/dataTest'

for l in train_list:
    l_dict = unpickle(l)
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        im_label_name = label_name[im_label]
        im_data = np.reshape(im_data, [3, 32, 32])            # CHW
        im_data = np.transpose(im_data, (1, 2, 0))            # HWC
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)    # CIFAR-10 stores RGB; cv2.imwrite expects BGR
        if not os.path.exists("{}/{}".format(save_path, im_label_name)):
            os.mkdir("{}/{}".format(save_path, im_label_name))
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name, im_name.decode("utf-8")), im_data)
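After the script runs, each class has its own subfolder of PNG files under dataTest (or dataTrain). A quick way to sanity-check the result, using the same paths as above, might be:

import glob
import os

# count the extracted images per class under the test-set folder
for cls_dir in sorted(glob.glob("cifar-10-batches-py/dataTest/*")):
    files = glob.glob(os.path.join(cls_dir, "*.png"))
    print(os.path.basename(cls_dir), len(files))  # each CIFAR-10 test class holds 1000 images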
3. Loading the custom dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob

label_name = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]

# map each class name to its integer label
label_dict = {}
for idx, name in enumerate(label_name):
    label_dict[name] = idx

def default_loader(path):
    return Image.open(path).convert('RGB')

# data augmentation pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),      # random crop, resized to 28x28
    transforms.RandomHorizontalFlip(),           # random horizontal flip
    transforms.RandomVerticalFlip(),             # random vertical flip
    transforms.RandomRotation(90),               # random rotation within +/-90 degrees
    transforms.RandomGrayscale(0.1),             # convert to grayscale with probability 0.1
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),  # random brightness/contrast/saturation/hue jitter
    transforms.ToTensor()                        # convert to a CHW float tensor
])

class MyDataset(Dataset):
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # the parent folder name is the class name (assumes '/' path separators), e.g. .../dataTrain/airplane/xxx.png
            im_label_name = im_item.split("/")[-2]
            imgs.append([im_item, label_dict[im_label_name]])
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        return len(self.imgs)

im_train_list = glob.glob("cifar-10-batches-py/dataTrain/*/*.png")  # training-set image paths
im_test_list = glob.glob("cifar-10-batches-py/dataTest/*/*.png")    # test-set image paths

train_dataset = MyDataset(im_train_list, transform=train_transform)      # training dataset
test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())  # test dataset

train_data_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True, num_workers=4)  # training data loader
test_data_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=False, num_workers=4)   # test data loader

print("num_of_train:", len(train_dataset))
print("num_of_test:", len(test_dataset))
Output of the code:
num_of_train: 50000
num_of_test: 10000
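A quick way to confirm the loaders work end to end is to pull a single batch and inspect its shapes; this is a minimal sketch using the objects defined above:

# take one batch from the training loader and inspect it
images, labels = next(iter(train_data_loader))
print(images.shape)   # torch.Size([6, 3, 28, 28]): batch of 6 RGB crops of size 28x28
print(labels)         # a tensor of 6 integer class indices in [0, 9]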
4. Building the VGG network
- Building the network model
import torch
import torch.nn as nn
import torch.nn.functional as F

# define the VGG-style network
class VGGbase(nn.Module):
    # initialization: build the convolutional stages
    def __init__(self):
        super(VGGbase, self).__init__()  # call the parent class initializer
        # first conv stage, feature map size: 28x28
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)  # max pooling: 28x28 -> 14x14
        # second conv stage, feature map size: 14x14
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)  # max pooling: 14x14 -> 7x7
        # third conv stage, feature map size: 7x7
        self.conv3_1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
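The listing above stops inside __init__ after conv3_1. A minimal, hedged sketch of one way to finish the class follows; the remaining layer names (conv3_2, max_pooling3, conv4_1, conv4_2, max_pooling4, fc), the padding on max_pooling3, and the 10-way linear classifier are assumptions made for illustration:

        # hypothetical continuation of __init__ (see the note above)
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)  # 7x7 -> 4x4 (padding assumed)
        # fourth conv stage, feature map size: 4x4
        self.conv4_1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 4x4 -> 2x2
        self.fc = nn.Linear(512 * 2 * 2, 10)  # 10 CIFAR-10 classes

    def forward(self, x):
        batch_size = x.size(0)
        out = self.conv1(x)
        out = self.max_pooling1(out)
        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)
        out = self.conv3_1(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)
        out = self.conv4_1(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)
        out = out.view(batch_size, -1)    # flatten to (batch, 512*2*2)
        out = self.fc(out)
        out = F.log_softmax(out, dim=1)   # log-probabilities over the 10 classes
        return out

With these assumptions, VGGbase()(torch.randn(1, 3, 28, 28)) should return a tensor of shape [1, 10].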
