动⼿实战CIFAR-10图像分类问题,本次实验基于pytorch
Kaggle⽐赛的⽹⻚地址是https://www.kaggle.com/c/cifar-10 。
从网站上可以下载得到对应得数据集,比赛数据分为训练集和测试集。训练集包含 50,000 图片。测试集包含 300,000 图片。两个数据集中的图像格式均为PNG,高度和宽度均为32像素,并具有三个颜色通道(RGB)。图像涵盖10个类别:飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,船和卡车。
下载完训练数据集train.7z和测试数据集test.7z后需要解压缩。解压缩后,将训练数据集、测试数据集以及训练数据集标签分别存放在以下3个路径:
• …/data/cifar10/train/train/[1-50000].png;
• …/data/cifar10/test/test/[1-300000].png;
• …/data/cifar10/trainLabels.csv。
数据的预处理,归一化,图像增广处理,提高模型的过拟合能力
import torch
import torchvision
from torchvision import datasets, transforms
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0,再把图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #R,G,B每层的归一化用到的均值和方差
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dir = '/home/deeplearning/cifar10/data/train'
test_dir = '/home/deeplearning/cifar10/data/test'
trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
testset = torchvision.datasets.ImageFolder(root=test_dir, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'forg', 'horse', 'ship', 'truck']
采用的模型,ResNet-18网络结构:ResNet全名Residual Network残差网络
import torch.nn as nn
import torch.nn.functional