2.2 使用Pytorch构建一个分类器
学习目标
- 了解分类器的任务和数据样式
- 掌握如何使用Pytorch实现一个分类器
分类器任务和数据介绍
- 构建一个将不同图像进行分类的神经网络分类器,对输入的图片进行判断并完成分类
- 本案例采用CIFAR0数据集作为原始图片数据
CIFAR10数据集介绍
- 数据集中每张图片的尺寸是33232,代表彩色3通道
- CIFAR10数据集共有10种不同的分类,分别是:airplane automobile bird cat deer dog frog horse ship truck
训练分类器步骤
- 1:使用torchvision下载CIFAR10数据集
- 2:定义卷积神经网络
- 3:定义损失函数
- 4:在训练集上训练模型
- 5:在测试集上测试模型
1:使用torchvision下载CIFAR10数据集
- 导入torchvision包来辅助下载数据集
import torch
import torchvision
import torchvision.transforms as transforms
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
#下载数据集并对图片进行调整,因为torchvision数据集的输出是PILImage格式,数据域在[0,1]
#我们将其转换成标准数据域[-1,1]的张量格式
#transform 数据转换器
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
# 下载下来的数据放在trainset里面
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
# DataLoader数据迭代器 将数据封装成DataLoader
# num_workers:两个线程读取数据
# batch_size=4 批处理
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
classes=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#输出
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
100.0%
Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
注意:
可能存在的两个问题:
- windows系统下运行上述代码,并且出现报错信息“BrokenPipeError”时,可以尝试将torch.utils.DataLoader()中的num_workers设置为0
- 下载CIFAR10数据集报错:urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certi
- 尽量使用linux系统学习深度学习
展示若干训练集的图片
在此处遇到问题可以访问我之前写的一篇文章
import torch
import torchvision
import torchvision.transforms as transforms
import ssl
import torch.utils.data as Data
ssl._create_default_https_context = ssl._create_unverified_context
#下载数据集并对图片进行调整,因为torchvision数据集的输出是PILImage格式,数据域在[0,1]
#我们将其转换成标准数据域[-1,1]的张量格式
#transform 数据转换器
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
# 下载下来的数据放在trainset里面
trainloader=torch.utils.data.DataLoader(dataset=trainset,batch_size=4,shuffle=True,num_workers=2)
# DataLoader数据迭代器 将数据封装成DataLoader
# num_workers:两个线程读取数据
# batch_size=4 批处理
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(dataset=testset,batch_size=4,shuffle=False,num_workers=2)
classes=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#展示若干训练集的图片
#导入画图包和numpy
import matplotlib.pyplot as plt
import numpy as np
#构建展示图片的函数
def imShow(img):
img=img/2+0.5
#img是tensor类型的数据,tensor类型转换成numpy类型的数据
npimg=img.numpy()
plt.imshow(np.transpose(npimg,(1,2,0)))#维度转换成 1 2 0 这三个维度
plt.show()
#从数据迭代器中读取一张图片
if __name__ == '__main__':
try:
dataiter = iter(trainloader)
images, labels = dataiter.next()
# 展示图片
imShow(torchvision.utils.make_grid(images))
# 使用网格的形式展示图片
# 打印标签label
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
except Exception as e:
print(e)
#输出
dog horse cat cat

定义卷积神经网络
仿照2.1节中的类来构造此处的类,唯一的区别在于次数采用的三通道 3-channel
import torch.nn as nn
import torch.nn.functional as F
class Net

本文详细介绍使用PyTorch构建图像分类器的过程,包括下载CIFAR10数据集、定义卷积神经网络、训练及测试模型等关键步骤。
最低0.47元/天 解锁文章
620

被折叠的 条评论
为什么被折叠?



