1.4 使用torch包训练一个分类集
其中包含torchvision包,它包含了处理一些基本图像数据集的方法。这些数据集包括 Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器, torchvision.datasets 和 torch.utils.data.DataLoader。
torchvision包不仅提供了巨大的便利,也避免了代码的重复。
这里是官方torchvision包介绍:
https://pytorch.org/docs/master/torchvision/index.html?highlight=torchvision#module-torchvision
这里使用CIFAR10数据集,它有如下10个类别 :‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是 3x32x32大小的,即,3颜色通道,32x32像素。
from IPython.display import Image
Image(filename = 'F:/jupyter notebook/CIFAR10.png', width=700, height=400)
训练一个图像分类器的步骤:
1.使用torchvision加载和归一化CIFAR10训练集和测试集;
2.定义一个卷积神经网络;
3.定义损失函数;
4.在训练集上训练网络;
5.在测试集上测试网络。
1. 读取和归一化 CIFAR10
使用torchvision可以非常容易地加载CIFAR10。
关于torchvision.transforms 的详细说明:
import torch
import torchvision
import torchvision.transforms as transforms
torchvision.transforms.Compose(list of transforms)
:将几种转换操作组合在一起
torchvision.transforms.ToTensor
:Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
torchvision.transforms.Normalize(mean, std, inplace=False)
:n个通道的均值和标准差为(M1,...,Mn)
、(S1,..,Sn)
,利用均值和标准差归一化张量。i.e
:input[channel] = (input[channel] - mean[channel]) / std[channel]
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
:
root (string):保存的根目录,如果download为true
train (bool, optional):如果为真,则从训练集创建数据集,否则从测试集创建数据集。
transform (callable, optional):对PIL图像进行指定的transform变换。
download (bool, optional) :如果为True,则从internet下载数据集并将其放在根目录中。
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
dataset:从该数据集加载数据;
batch_size (int, optional) :每批次加载得样品数量,默认为1
shuffle (bool, optional) :设置为true,每次都会重新打乱顺序
num_workers (int, optional) :使用多少子进程加载数据,默认为0,意味着只有主程序加载数据
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform,
download=True)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform,
download=True)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False