CNN手写数字识别demo——pyTorch代码实现

import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

EPOCH = 1                           # 迭代次数
BATCH_SIZE = 50                     # 每一批训练的数据大小
LR = 0.001                          # 学习率
DOWNLOAD_MNIST = False              # 是否需要下载数据集,首次运行需要把该参数指定为True

'''
torchvisions 是一个数据集的库,下载torch的时候一起下载的
MNIST 是手写数字的数据集
root 是数据集的根目录
train为True 代表的是下载训练数据,为False代表下载的是测试数据
transform 是指定数据格式,ToTensor是将图片转化为0到1之间的Tensor类型的数据
download 表示是否需要下载数据集,如果已经下载过,就设置为False
'''
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)

# # 查看数据size
# print(train_data.data.size())
# # 查看标签size
# print(train_data.targets.size())
# # 在画布上展示图片
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# # 设置标题,%i 代表转为有符号的十进制
# plt.title('%i'%train_data.targets[0])
# plt.show()

#  训练集分批
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# 下载测试集数据
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False
)

'''
这里的test_data.data的size是(10000, 28, 28)
维度一 表示测试集总数
维度二 表示第多少个像素的横坐标对应的RGB值(这里已从三层(xxx,xxx,xxx)转为一层(k), k的值是0到1之间)
维度三 与参数二相似,表示的是纵坐标
添加维度(图片特征的高度)并切片后,test_x的size是(2000, 1, 28, 28)
PyTorch图层以Torch.FloatTensor作为输入,所以这里要指定一下torch.float32,否则在后续的训练中会报错。
最后除255是因为像素点是0-255范围的,这样可以做归一化处理,把特征值固定在0-1范围区间。如果不做
归一化处理,会在某些情况下(比如RNN神经网络)出现欠拟合的问题。
'''
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]/255
# 获取正确分类的结果
test_y = test_data.targets[:2000
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值