文章目录
1.引用
torchvision提供了一些常用的数据集、模型、转换函数等
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
2.内置图片数据集加载
torch的内置图片数据集均在datasets模块下,包含Catletch、CelebA、CIFAR、Cityscapes、COCO、Fashion-MNIST、ImageNet、MNIST等。
MNIST数据集是0-9手写数字数据集。
train=True表示是训练数据
torchvision.transforms包含了转换函数
这里用到了ToTensor类,该类的主要作用有以下3点:
①将输入转换成张量
②读取图片的格式规范为(channel,heigth,width)
③将图片像素的取值范围归一化0-1
train_ds=torchvision.datasets.MNIST('data/',train=True,transform=transforms.ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST('data/',train=False,transform=transforms.ToTensor(),download=True)
3.处理为batch类型
DataLoader有以下4个目的:
①使用shuffle参数对数据集做乱序的操作(随机打乱)
②将数据采样为小批次,可用batch_size参数指定批次大小(小批次)
③可以充分利用多个子进程加速数据预处理(多线程)
④可通过collate_fn参数传递批次数据中的处理函数,实现对批次数据进行转换处理(转换处理)
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=64)
上述两行代码创建了DataLoader类型的train_dl和test_dl
DataLoader是可迭代对象,next方法返回一个批次的图像imgs和对应一个批次的标签labels
4.设置运行设备
机器学习或者深度学习需要选择程序运行的设备是CPU还是GPU,GPU就是通常所说的需要有显卡。
device='cuda' if torch.cuda.is_available() else 'cpu'
print('use {} device'.format(device))
5.查看数据
imgs,labels=next(iter(train_dl))
print(imgs.shape)
print(labels.shape)
结果:
torch.Size([64, 1, 28, 28])
torch.Size([64])
6.绘图查看数据图片
imgs[:10]查看前10条数据
np.squeeze从数组的形状中删除维度为 1 的维度。
np.unsqueeze从数组的形状中添加维度为 1 的维度。
注:只有数组长度在该维度上为 1,那么该维度才可以被删除。如果不是1,那么删除的话会报错
报错信息:cannot select an axis to squeeze out which has size not equal to one
(1)不显示图片标签
plt.figure(figsize=(10,1))
for i,img in enumerate(imgs[:10]):
npimg=img.numpy()
npimg=np.squeeze(npimg)#形状由(1,28,28)转换为(28,28)
plt.subplot(1,10,i+1)
plt.imshow(npimg) #在子图中绘制单张图片
plt.axis('off') #关闭显示子图坐标
print(labels[:10])
plt.show()
(2)打印图片标签
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
class_label_str=''
img_label_list=list(zip(imgs,labels))
for i,(img,label) in enumerate(img_label_list):
nimg=np.array(img)
nimg=np.squeeze(nimg)
plt.subplot(8,8,i+1)
plt.title(str(label.item()))
plt.imshow(nimg)
plt.axis('off')
'''按照图片显示格式打印所有标签:i!=0实现按行打印的同时第一行前面无空行,按每行8列打印'''
if i!=0 and i%8==0:
class_label_str +='\n'
class_label_str += classes[label.item()]+'\t'
print(class_label_str)
plt.show()
(3)图片显示标签
我这里以'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'这些类别示例,作用于手写字体图像分类时,要更改成0-9
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
img_label_list=list(zip(imgs,labels))
for i,(img,label) in enumerate(img_label_list):
nimg=img.transpose(0, 2)
nimg=nimg.numpy()
plt.subplot(5,5,i+1)
plt.title(classes[label.item()])
plt.imshow(nimg)
plt.axis('off')
plt.show()
7.定义卷积函数
定义卷积函数才是算法模型的真正开始,卷积层一般是必不可少的,是机器学习和深度学习的灵魂与基石所在。
class net(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,6,5)
self.conv2=nn.Conv2d(6,16,5)
self.linear1 = nn.Linear(16*4*4,20)
self.linear2 = nn.Linear(20,10)
def forward(self,input):
x=torch.max_pool2d(torch.

最低0.47元/天 解锁文章
1485

被折叠的 条评论
为什么被折叠?



