1.4 使用torch包训练一个分类集

本文介绍了如何使用torchvision包加载和归一化CIFAR10数据集,定义卷积神经网络,设置损失函数,进行训练及在测试集上评估。内容包括数据预处理、模型构建、损失函数的选择、训练过程和测试结果分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.4 使用torch包训练一个分类集

其中包含torchvision包,它包含了处理一些基本图像数据集的方法。这些数据集包括 Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器, torchvision.datasets 和 torch.utils.data.DataLoader。

torchvision包不仅提供了巨大的便利,也避免了代码的重复。

这里是官方torchvision包介绍:
https://pytorch.org/docs/master/torchvision/index.html?highlight=torchvision#module-torchvision

这里使用CIFAR10数据集,它有如下10个类别 :‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是 3x32x32大小的,即,3颜色通道,32x32像素。

from IPython.display import Image
Image(filename = 'F:/jupyter notebook/CIFAR10.png', width=700, height=400)

在这里插入图片描述

训练一个图像分类器的步骤:

1.使用torchvision加载和归一化CIFAR10训练集和测试集;

2.定义一个卷积神经网络;

3.定义损失函数;

4.在训练集上训练网络;

5.在测试集上测试网络。

1. 读取和归一化 CIFAR10

使用torchvision可以非常容易地加载CIFAR10。

关于torchvision.transforms 的详细说明:

https://pytorch.org/docs/master/torchvision/transforms.html?highlight=torchvision#module-torchvision.transforms.functional

import torch
import torchvision
import torchvision.transforms as transforms

torchvision.transforms.Compose(list of transforms):将几种转换操作组合在一起

torchvision.transforms.ToTensor:Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

torchvision.transforms.Normalize(mean, std, inplace=False):n个通道的均值和标准差为(M1,...,Mn)(S1,..,Sn),利用均值和标准差归一化张量。i.e:input[channel] = (input[channel] - mean[channel]) / std[channel]

transform = transforms.Compose(
    [transforms.ToTensor(),
      transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False):

root (string):保存的根目录,如果download为true

train (bool, optional):如果为真,则从训练集创建数据集,否则从测试集创建数据集。

transform (callable, optional):对PIL图像进行指定的transform变换。

download (bool, optional) :如果为True,则从internet下载数据集并将其放在根目录中。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataset:从该数据集加载数据;

batch_size (int, optional) :每批次加载得样品数量,默认为1

shuffle (bool, optional) :设置为true,每次都会重新打乱顺序

num_workers (int, optional) :使用多少子进程加载数据,默认为0,意味着只有主程序加载数据

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, 
                                         download=True)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, 
                                         download=True)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值