【PyTorch】2.2 使用Pytorch构建一个分类器(使用真实数据集CIFAR10)

本文详细介绍使用PyTorch构建图像分类器的过程,包括下载CIFAR10数据集、定义卷积神经网络、训练及测试模型等关键步骤。

2.2 使用Pytorch构建一个分类器

学习目标

  • 了解分类器的任务和数据样式
  • 掌握如何使用Pytorch实现一个分类器

分类器任务和数据介绍

  • 构建一个将不同图像进行分类的神经网络分类器,对输入的图片进行判断并完成分类
  • 本案例采用CIFAR0数据集作为原始图片数据
CIFAR10数据集介绍
  • 数据集中每张图片的尺寸是33232,代表彩色3通道
  • CIFAR10数据集共有10种不同的分类,分别是:airplane automobile bird cat deer dog frog horse ship truck
训练分类器步骤
  • 1:使用torchvision下载CIFAR10数据集
  • 2:定义卷积神经网络
  • 3:定义损失函数
  • 4:在训练集上训练模型
  • 5:在测试集上测试模型
1:使用torchvision下载CIFAR10数据集
  • 导入torchvision包来辅助下载数据集
import torch
import torchvision
import torchvision.transforms as transforms
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
#下载数据集并对图片进行调整,因为torchvision数据集的输出是PILImage格式,数据域在[0,1]
#我们将其转换成标准数据域[-1,1]的张量格式
#transform 数据转换器
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
# 下载下来的数据放在trainset里面
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
# DataLoader数据迭代器 将数据封装成DataLoader
# num_workers:两个线程读取数据
# batch_size=4 批处理

testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
classes=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#输出
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
100.0%
Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified

在这里插入图片描述注意:

可能存在的两个问题:

  • windows系统下运行上述代码,并且出现报错信息“BrokenPipeError”时,可以尝试将torch.utils.DataLoader()中的num_workers设置为0
  • 下载CIFAR10数据集报错:urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certi
  • 尽量使用linux系统学习深度学习
展示若干训练集的图片

在此处遇到问题可以访问我之前写的一篇文章

import torch
import torchvision
import torchvision.transforms as transforms
import ssl
import torch.utils.data as Data
ssl._create_default_https_context = ssl._create_unverified_context
#下载数据集并对图片进行调整,因为torchvision数据集的输出是PILImage格式,数据域在[0,1]
#我们将其转换成标准数据域[-1,1]的张量格式
#transform 数据转换器
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
# 下载下来的数据放在trainset里面
trainloader=torch.utils.data.DataLoader(dataset=trainset,batch_size=4,shuffle=True,num_workers=2)
# DataLoader数据迭代器 将数据封装成DataLoader
# num_workers:两个线程读取数据
# batch_size=4 批处理

testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(dataset=testset,batch_size=4,shuffle=False,num_workers=2)
classes=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#展示若干训练集的图片
#导入画图包和numpy
import matplotlib.pyplot as plt
import numpy as np

#构建展示图片的函数
def imShow(img):
    img=img/2+0.5
    #img是tensor类型的数据,tensor类型转换成numpy类型的数据
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))#维度转换成 1 2 0 这三个维度
    plt.show()

#从数据迭代器中读取一张图片
if __name__ == '__main__':
    try:
        dataiter = iter(trainloader)
        images, labels = dataiter.next()

        # 展示图片
        imShow(torchvision.utils.make_grid(images))
        # 使用网格的形式展示图片
        # 打印标签label
        print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
    except Exception as e:
        print(e)
#输出
dog horse   cat   cat

在这里插入图片描述

定义卷积神经网络

仿照2.1节中的类来构造此处的类,唯一的区别在于次数采用的三通道 3-channel

import torch.nn as nn
import torch.nn.functional as F

class Net
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值