文章目录
RNN实现MNIST手写数据识别
一、代码
import torch
import torchvision.datasets
from torch import nn
import torch.utils.data as Data
EPOCH=1 #训练多少次
BATCH_SIZE =64 #批训练数量
TIME_STEP=28 #nn时间步数/图片高度
INPUT_SIZE=28 #nn每步输入值/图片每行像素
LR=0.01 #学习率
DOWNLOAD_MNIST=False #mnist数据
#批训练
train_data = torchvision.datasets.MNIST(root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD_MNIST)
train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
#测试
test_data = torchvision.datasets.MNIST(root='./mnist'<