Source code adapted, with minor modifications, from https://github.com/zergtant/pytorch-handbook/blob/master/chapter1/4_cifar10_tutorial.ipynb
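CIFAR10 data
CIFAR10 is a basic image dataset with ten classes. The training set has 50,000 images and the test set has 10,000 images, all at 32×32 resolution. PyTorch's torchvision makes it easy to download and use CIFAR10, as in the following code: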
import torch
import torchvision
import torchvision.transforms as transforms
# Define hyperparameters
BATCH_SIZE = 4
EPOCH = 2
# Load the CIFAR10 dataset with torchvision; the transform scales pixels to [0,1] and then normalizes them to [-1,1]
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data',train = True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE,
shuffle = True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data',train = False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size = BATCH_SIZE,
shuffle = False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
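To make the effect of transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) in the code above concrete, here is a minimal sketch in plain Python that applies the per-channel arithmetic (value - mean) / std to a few pixel values after scaling them to [0,1]. The helper name normalize_pixel and the sample pixel values are hypothetical and chosen only for illustration; torchvision performs the equivalent arithmetic on whole tensors.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Apply (value - mean) / std to one channel value in [0, 1]."""
    return (value - mean) / std


# Raw 8-bit pixel values, scaled to [0, 1] by dividing by 255 (as ToTensor does).
raw_pixels = [0, 51, 255]
scaled = [p / 255 for p in raw_pixels]

# With mean = std = 0.5, the endpoints 0 and 1 map to -1 and 1.
normalized = [normalize_pixel(s) for s in scaled]

for raw, norm in zip(raw_pixels, normalized):
    print(raw, "->", round(norm, 4))
```

With mean 0.5 and std 0.5 the darkest pixel lands on -1 and the brightest on 1, which is why the mean argument is 0.5 rather than 0.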
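Code notes
PyTorch provides convenient interfaces for downloading common datasets, including MNIST and CIFAR10, and returns the training and test sets:
- torchvision.datasets.CIFAR10() downloads the data directly, and train=True/False selects the training set or the test set. The raw data are 32×32×3 RGB images with values in [0,255];
- Training data usually need normalization, which can be applied while reading the data via transform=transform. In general, transform = torchvision.transforms.ToTensor() makes torchvision convert [0,255] values to float RGB in [0,1]; this article then applies a further normalization step;
- transforms.Compose chains several transforms. Here transforms.ToTensor() produces values in [0,1], and transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) then maps the [0,1] RGB values to [-1,1]. Note that I originally assumed the first argument had to be (0,0,0) to center the data at zero, but the transforms source code shows: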
output[channel] = (input[channel] - mean[channel]) / std[channel]
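In other words, ([0,1]-0.5)/0.5 = [-1,1];
- torchvision.datasets.CIFAR10 already returns a dataset, so it can be passed straight to DataLoader for batch loading. BATCH_SIZE is the number of samples per batch; this article uses four images per iteration;
- Another point to note about torch.utils.data.DataLoader is that num_workers=2 loads data with two worker processes in parallel. On platforms that start worker processes with spawn (Windows, for example), running the script directly raises a RuntimeError: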
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
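This is because Python code that uses multiprocessing must start its processes from the main program, so the main code has to run under if __name__ == '__main__':. For details, see this Zhihu article: https://zhuanlan.zhihu.com/p/39542342;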
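The guard described in the last note can be sketched as follows. This is a minimal illustration under stated assumptions, not the original tutorial's code: it replaces CIFAR10 with a small synthetic TensorDataset so that nothing is downloaded, and the function name count_batches is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def count_batches(num_workers=2, batch_size=4):
    """Iterate over a small synthetic dataset and return the number of batches."""
    # Synthetic stand-in for CIFAR10: 16 random 3x32x32 "images" with integer labels.
    images = torch.rand(16, 3, 32, 32)
    labels = torch.randint(0, 10, (16,))
    loader = DataLoader(TensorDataset(images, labels),
                        batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)
    return sum(1 for _ in loader)


if __name__ == '__main__':
    # Worker processes are started only from inside this guard, so a spawned
    # worker that re-imports this module does not try to start workers itself.
    print(count_batches())
```

Keeping the DataLoader iteration inside a function and calling it only under the guard is one common arrangement; setting num_workers=0 loads data in the main process and avoids the issue altogether, at the cost of parallel loading.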
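Displaying images
import matplotlib.pyplot as plt
import numpy as np
plt.imshow(trainset.data[86])  # trainset.data stores the raw data as a NumPy array
plt.show()
dataiter = iter(trainloader)
images, labels = next(dataiter)  # the built-in next() works across PyTorch versions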
images_comb = torchvision.utils.make_grid(images)
images_comb_unnor = (images_comb*0.5+0.5).numpy()
plt.imshow(np.transpose(images_comb_unnor, (1, 2, 0)))
plt.show()
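The two array operations in the snippet above, undoing the normalization and moving the channel axis last, can be checked in isolation. The sketch below uses NumPy only and a random array as a stand-in for one normalized image; the helper name to_displayable is hypothetical.

```python
import numpy as np


def to_displayable(chw_image, mean=0.5, std=0.5):
    """Undo Normalize(mean, std) and reorder (C, H, W) to (H, W, C)."""
    unnormalized = chw_image * std + mean          # back from [-1, 1] to [0, 1]
    return np.transpose(unnormalized, (1, 2, 0))   # channels last, as plt.imshow expects


# Synthetic stand-in for one normalized image: values drawn uniformly from [-1, 1).
rng = np.random.default_rng(0)
fake_image = rng.uniform(-1.0, 1.0, size=(3, 32, 32))

hwc = to_displayable(fake_image)
print(hwc.shape, float(hwc.min()) >= 0.0, float(hwc.max()) <= 1.0)
```

The resulting array has the (M,N,3) layout with float values in [0,1], so it could be passed to plt.imshow() directly.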
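Python's matplotlib module is a convenient plotting tool:
- plt.imshow() accepts NumPy arrays and renders an (M,N,3) array as an RGB image; the RGB values can be floats in [0,1] or integers in [0,255];
- Interestingly, I had always assumed that trainset held [-1,1] tensors once the transform was applied, but trainset.data still keeps the original array with values in [0,255] and shape 32×32×3 per image, so plt.imshow() can render it directly;
- Indexing trainset returns a tuple of image and label, where the image is a torch.Tensor with values in [-1,1] and shape 3×32×32, as shown below: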
trainset[0][1]
Out[13]: 6
trainset[0][0]
Out[14]:
tensor([[[-0.5373, -0.6627, -0.6078, ..., 0.2392, 0.1922, 0.1608],
[-0.8745, -1.0000, -0.8588, ..., -0.0353, -0.0667, -0.0431],
[-0.8039, -0.8745, -0.6157, ..., -0.0745, -0.0588, -0.1451],
...,
[ 0.6314, 0.5765, 0.5529, ..., 0.2549, -0.5608, -0.5843],
[ 0.4118, 0.3569, 0.4588, ..., 0.4431, -0.2392, -0.3490],
[ 0.3882, 0.3176, 0.4039, ..., 0.6941, 0.1843, -0.0353]],
[[-0.5137, -0.6392, -0.6235, ..., 0.0353, -0.0196, -0.0275],
[-0.8431, -1.0000, -0.9373, ..., -0.3098, -0.3490, -0.3176],
[-0.8118, -0.9451, -0.7882, ..., -0.3412, -0.3412, -0.4275],
...,
[ 0.3333, 0.2000, 0.2627, ..., 0.0431, -0.7569, -0.7333],
[ 0.0902, -0.0353, 0.1294, ..., 0.1608, -0.5137, -0.5843],
[ 0.1294, 0.0118, 0.1137, ..., 0.4431, -0.0745, -0.2784]],
[[-0.5059, -0.6471, -0.6627, ..., -0.1529, -0.2000, -0.1922],
[-0.8431, -1.0000, -1.0000, ..., -0.5686, -0.6078, -0.5529],
[-0.8353, -1.0000, -0.9373, ..., -0.6078, -0.6078, -0.6706],
...,
[-0.2471, -0.7333, -0.7961, ..., -0.4510, -0.9451, -0.8431],
[-0.2471, -0.6706, -0.7647, ..., -0.2627