# PyTorch代码识别手写数字 (Recognizing handwritten digits with PyTorch)
import torch
import torchvision
from torchvision import datasets, transforms
# Preprocessing applied to every image: convert it to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# MNIST training split, downloaded into ./data on first use and passed
# through the shared preprocessing pipeline.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Serve the training data in shuffled batches of 20000 samples.
trainsetloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=20000,
    shuffle=True,
)
# MNIST held-out test split. `train=False` is required here: with
# `train=True` this would load the training images a second time, and any
# accuracy measured on `testsetloader` would be computed on data the model
# was trained on.
testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Serve the test data in shuffled batches of 20000 samples.
testsetloader = torch.utils.data.DataLoader(
    testset,
    batch_size=20000,
    shuffle=True,
)
# Size constants; presumably the layer widths of the model built below
# (28*28 matches a flattened 28x28 image) -- TODO confirm against the model.
first_in = 28 * 28
first_out = 128
second_out = 10
model = torch.nn