import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
EPOCH = 1 # 迭代次数
BATCH_SIZE = 50 # 每一批训练的数据大小
LR = 0.001 # 学习率
DOWNLOAD_MNIST = False # 是否需要下载数据集,首次运行需要把该参数指定为True
'''
torchvisions 是一个数据集的库,下载torch的时候一起下载的
MNIST 是手写数字的数据集
root 是数据集的根目录
train为True 代表的是下载训练数据,为False代表下载的是测试数据
transform 是指定数据格式,ToTensor是将图片转化为0到1之间的Tensor类型的数据
download 表示是否需要下载数据集,如果已经下载过,就设置为False
'''
train_data = torchvision.datasets.MNIST(
root='./mnist',
train=True,
transform=torchvision.transforms.ToTensor(),
download=DOWNLOAD_MNIST
)
# # 查看数据size
# print(train_data.data.size())
# # 查看标签size
# print(train_data.targets.size())
# # 在画布上展示图片
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# # 设置标题,%i 代表转为有符号的十进制
# plt.title('%i'%train_data.targets[0])
# plt.show()
# 训练集分批
train_loader = Data.DataLoader(
dataset=train_data,
batch_size=BATCH_SIZE,
shuffle=True,
)
# 下载测试集数据
test_data = torchvision.datasets.MNIST(
root='./mnist',
train=False
)
'''
这里的test_data.data的size是(10000, 28, 28)
维度一 表示测试集总数
维度二 表示第多少个像素的横坐标对应的RGB值(这里已从三层(xxx,xxx,xxx)转为一层(k), k的值是0到1之间)
维度三 与参数二相似,表示的是纵坐标
添加维度(图片特征的高度)并切片后,test_x的size是(2000, 1, 28, 28)
PyTorch图层以Torch.FloatTensor作为输入,所以这里要指定一下torch.float32,否则在后续的训练中会报错。
最后除255是因为像素点是0-255范围的,这样可以做归一化处理,把特征值固定在0-1范围区间。如果不做
归一化处理,会在某些情况下(比如RNN神经网络)出现欠拟合的问题。
'''
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]/255
# 获取正确分类的结果
test_y = test_data.targets[:2000
CNN手写数字识别demo——pyTorch代码实现
最新推荐文章于 2025-03-05 14:16:43 发布