Pytorch搭建CIFAR10的CNN卷积神经网络

本文档详细介绍了如何使用Pytorch构建和训练CIFAR10数据集的卷积神经网络(CNN)。内容包括数据预处理、网络结构、损失函数和优化器的选择,以及训练过程和精度验证。通过调整网络参数,可以逐步提高模型的准确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

源码参考https://github.com/zergtant/pytorch-handbook/blob/master/chapter1/4_cifar10_tutorial.ipynb
稍作修改

CIFAR10数据

CIFAR10是基本的图片数据库,共十个分类,训练集有50000张图片,测试集有10000张图片,图片均为32*32分辨率。Pytorch的torchvision可以很方便的下载使用CIFAR10的数据,代码如下:

import torch
import torchvision
import torchvision.transforms as transforms

#定义超参数
BATCH_SIZE = 4
EPOCH = 2

#torchvision模块载入CIFAR10数据集,并且通过transform归一化到[0,1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data',train = True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE,
                                          shuffle = True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data',train = False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size = BATCH_SIZE,
                                          shuffle = False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')                                          

代码注释

pytorch提供了很方便的接口下载常用数据库,包括MNIST,CIFAR10等,并且输出训练集以及测试集:

  1. torchvision.datasets.CIFAR10()直接下载所有数据,通过train=True/False可以确定赋给训练集或者测试集,数据为32323的[0,255]的RGB image图像;
  2. 用于训练的数据集通常有归一化需求,读取数据的时候可以直接通过transform=transform实现,一般来说transform = torchvision.transforms.ToTensor()可以使的torchvision将[0,255]输出为[0,1]的float RGB,本文中继续做了归一化normalization;
  3. transforms.Compose实现多个transform命令组合,本文中transforms.ToTensor()实现输出为[0,1],紧接着 transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))将[0,1]的RGB归一化为[-1,1]。需要注意的是原本以为第一个参数应该是(0,0,0)才是归一化到均值为0。但是通过transforms的源码发现:
    output[channel] = (input[channel] - mean[channel]) / std[channel]
    也就是说((0,1)-0.5)/0.5=(-1,1);
  4. torchvision.datasets.CIFAR10输出的实际已经是dataset了,可以直接用dataloader进行批加载,BATCH_SIZE是每批取的数据,本文每次计算迭代取四张图;
  5. torch.utils.data.DataLoader还有个需要注意的点是num_workers=2同时并行两个核,但是如果直接运行会报RuntimeError:
    An attempt has been made to start a new process before the
    current process has finished its bootstrapping phase.
    这是因为python如果要用到并行计算多进程必须在主程序中,需要if name == ‘main’:来运行主程序,具体可参见知乎的一片文章https://zhuanlan.zhihu.com/p/39542342;

显示图片

plt.imshow(trainset.data[86]) #trainset.data中储存了原始数据,并且是array格式
plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()
images_comb = torchvision.utils.make_grid(images)
images_comb_unnor = (images_comb*0.5+0.5).numpy()
plt.imshow(np.transpose(images_comb_unnor, (1, 2, 0)))
plt.show()

Python的模块matplotlib是很方便的绘图:

  1. plt.imshow()支持numpy array,可以对(M,N,3)的RGB输出图像,RGB值可以是[0,1]的float,也可以是[0,255]的int;
  2. 有意思的是一致以为trainset经过transform出来就已经是[-1,1]的tensor,但实际上trainset.data中还是保留了原始的array[0.255](32323),plt.imshow()可是直接生成图片;
  3. trainset本身传递的是元组,分别是image和label,image为torch.tensor,[-1,1] (33232),如下:
trainset[0][1]
Out[13]: 6
trainset[0][0]
Out[14]: 
tensor([[[-0.5373, -0.6627, -0.6078,  ...,  0.2392,  0.1922,  0.1608],
         [-0.8745, -1.0000, -0.8588,  ..., -0.0353, -0.0667, -0.0431],
         [-0.8039, -0.8745, -0.6157,  ..., -0.0745, -0.0588, -0.1451],
         ...,
         [ 0.6314,  0.5765,  0.5529,  ...,  0.2549, -0.5608, -0.5843],
         [ 0.4118,  0.3569,  0.4588,  ...,  0.4431, -0.2392, -0.3490],
         [ 0.3882,  0.3176,  0.4039,  ...,  0.6941,  0.1843, -0.0353]],
        [[-0.5137, -0.6392, -0.6235,  ...,  0.0353, -0.0196, -0.0275],
         [-0.8431, -1.0000, -0.9373,  ..., -0.3098, -0.3490, -0.3176],
         [-0.8118, -0.9451, -0.7882,  ..., -0.3412, -0.3412, -0.4275],
         ...,
         [ 0.3333,  0.2000,  0.2627,  ...,  0.0431, -0.7569, -0.7333],
         [ 0.0902, -0.0353,  0.1294,  ...,  0.1608, -0.5137, -0.5843],
         [ 0.1294,  0.0118,  0.1137,  ...,  0.4431, -0.0745, -0.2784]],
        [[-0.5059, -0.6471, -0.6627,  ..., -0.1529, -0.2000, -0.1922],
         [-0.8431, -1.0000, -1.0000,  ..., -0.5686, -0.6078, -0.5529],
         [-0.8353, -1.0000, -0.9373,  ..., -0.6078, -0.6078, -0.6706],
         ...,
         [-0.2471, -0.7333, -0.7961,  ..., -0.4510, -0.9451, -0.8431],
         [-0.2471, -0.6706, -0.7647,  ..., -0.2627
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值