CIFAR-10介绍
CIFAR-10是一个包含60000张图片的数据集。其中每张照片为32*32的彩色照片,每个像素点包括RGB三个数值
所有照片分属10个不同的类别,分别是airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck
其中五万张时训练集,一万张时测试集
下载
http://www.cs.toronto.edu/~kriz/cifar.html
数据集处理
transforms.Compose
可以理解为数据处理格式的定义
Compose的作用就是把对图像处理的方法集中起来,使用的比较多的方法有:
中心化:CenterCrop()
转换成张量:ToTensor()
正则化:Normalize()
等等
注意Compose是会按照顺序进行处理的
##torchvision.datasets
训练集的处理和MNIST差不太多(包括CIFAR-100)
import torchvision.datasets as dset
dset.CIFAR10(root, train = True, transform = None, target_transform = None, download = False)
dset.CIFAR100(root, train = True, transform = None, target_transform = None, download = False)
简单的卷积池化全连接
step1
因为先前弄过MNIST数据集,所以第一次先只改动数据集的处理部分,以及输入通道的问题,其余的先不动,代码如下
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
num_epoch = 10
BATCH_SIZE = 50
transform = transforms.Compose(
[transforms.ToTensor()
]
)
#MNIST数据集加载
train_dataset = datasets.CIFAR10(
root= '/home/cxm-irene/PyTorch/data/cifar10',
train= True,
transform= transform,
download= False
)
train_loader = data.DataLoader(
dataset= train_dataset,
batch_size= BATCH_SIZE,
shuffle= True
)
test_dataset = datasets.CIFAR10(
root= '/home/cxm-irene/PyTorch/data/cifar10',
train= False,
transform= transform,
download= False
)
test_loader = data.DataLoader(
dataset= test_dataset,
batch_size= BATCH_SIZE,
shuffle= True
)
#搭建网络
class Net_CIFAR10(torch.nn.Module):
def __init__(self):
super(Net_CIFAR10, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels= 3,
out_channels= 16,
kernel_size= 5,
stride= 1,
padding= 2,
),
nn.ReLU(),
nn.MaxPool2d(kernel_size= 2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels= 16,
out_channels= 32,
kernel_size= 5,
stride= 1,
padding=