目录
2.1 torchvision.transforms的图形数据处理方法
2.1.1 torchvision.transforms.ToTensor
2.1.2 torchvision.transforms.Normalize(mean, std)
2.1.3 torchvision.transforms.Compose(transforms)
2.2 准备MNIST数据集的Dataset和DataLoader
1.思路和流程分析
流程:
-
准备数据,MNIST数据集可以使用Pytorch API加载,所以不需要自定义dataset类,但是需要准备DataLoader
-
构建模型,这里仅仅使用全连接层堆叠来实现MNIST手写数字识别,如果采用CNN等其他深度网络模型结构,只需要在模型定义那一步做出改变即可
-
训练模型,更新参数
-
模型的保存,保存模型,便于后续的使用
-
模型的评估,使用测试集,观察模型的好坏
2.准备训练集和测试集
调用MNIST返回的结果中图形数据是一个Image对象,需要对其进行处理,对于图像数据的处理,可以采用torchvision.transfroms的方法
2.1 torchvision.transforms的图形数据处理方法
2.1.1 torchvision.transforms.ToTensor
把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W]
其中(H,W,C)意思为(高,宽,通道数),黑白图片的通道数只有1,其中每个像素点的取值为[0,255],彩色图片的通道数为(R,G,B),每个通道的每个像素点的取值为[0,255],三个通道的颜色相互叠加,形成了各种颜色
示例如下:
from torchvision import transforms
import numpy as np
data = np.random.randint(0, 255, size=12)
img = data.reshape(2,2,3)
print(img.shape)
# 注意这种形式,ToTensor方法不能传入参数
img_tensor = transforms.ToTensor()(img) # 转换成tensor
print(img_tensor)
print(img_tensor.shape)
输入如下:
shape:(2, 2, 3)
img_tensor:tensor([[[215, 171],
[ 34, 12]],
[[229, 87],
[ 15, 237]],
[[ 10, 55],
[ 72, 204]]], dtype=torch.int32)
new shape:torch.Size([3, 2, 2])
2.1.2 torchvision.transforms.Normalize(mean, std)
给定均值:mean,shape和图片的通道数相同(指的是每个通道的均值),方差:std,和图片的通道数相同(指的是每个通道的方差),将会把Tensor规范化处理。
即:Normalized_image=(image-mean)/std
示例如下:
from torchvision import transforms
import numpy as np
import torchvision
data = np.random.randint(0, 255, size=12)
img = data.reshape(2,2,3)
img = transforms.ToTensor()(img) # 转换成tensor
print(img)
print("*"*100)
# 有三个通道就得传入一个长度为3的元组
norm_img = transforms.Normalize((10,10,10), (1,1,1))(img) #进行规范化处理
print(norm_img)
输出结果如下:
tensor([[[177, 223],
[ 71, 182]],
[[153, 120],
[173, 33]],
[[162, 233],
[194, 73]]], dtype=torch.int32)
***************************************************************************************
# 167就是(177-10)/1,其余的类似
tensor([[[167, 213],
[ 61, 172]],
[[143, 110],
[163, 23]],
[[152, 223],
[184, 63]]], dtype=torch.int32)
注意:在sklearn中,默认上式中的std和mean为数据每列的std和mean,sklearn会在标准化之前算出每一列的std和mean。
但是在api:Normalize中并没有帮我们计算,所以我们需要手动计算
-
当mean为全部数据的均值,std为全部数据的std的时候,才是进行了标准化。
-
如果mean(x)不是全部数据的mean的时候,std(y)也不是的时候,Normalize后的数据分布满足下面的关系:

2.1.3 torchvision.transforms.Compose(transforms)
该方法可以将多个transform组合起来使用。
transforms.Compose([
torchvision.transforms.ToTensor(), #先转化为Tensor
torchvision.transforms.Normalize(mean,std) #在进行正则化
])
2.2 准备MNIST数据集的Dataset和DataLoader
# 准备训练集
import torchvision
#准备数据集,其中0.1307,0.3081为MNIST数据的均值和标准差,这样操作能够对其进行标准化
#因为MNIST只有一个通道(黑白图片),所以元组中只有一个值
dataset = torchvision.datasets.MNIST('/data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
]))
#准备数据迭代器
train_dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=True)
# 准

本文详细介绍了使用PyTorch构建和训练一个简单的神经网络模型,用于识别MNIST数据集的手写数字。从数据预处理、模型构建、训练过程到模型保存与加载,以及模型评估,全方位解析了深度学习模型的实战步骤。通过三轮全连接层,实现了90%以上的识别准确率。
最低0.47元/天 解锁文章
1322

被折叠的 条评论
为什么被折叠?



