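# 参考莫凡博客进行MNIST数据集的训练，临时记录所使用的代码。
# (Training on the MNIST dataset following Morvan's (莫凡) blog tutorial; a temporary record of the code used.)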
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# Fix the global RNG seed so that weight initialisation and DataLoader
# shuffling are reproducible from run to run.
torch.manual_seed(1)

# --- Hyper-parameters -------------------------------------------------------
EPOCH = 1  # passes over the training set (presumably consumed by the training loop below)
BATCH_SIZE = 50  # samples per mini-batch
LR = 0.001  # learning rate for the optimiser (defined outside this view)

# torchvision does not re-download MNIST when the files already exist under
# the root directory, so leaving this True is harmless after the first run.
DOWNLOAD_MNIST = True

# Single source of truth for where the dataset files live.
MNIST_ROOT = './mnist'

# Training split. ToTensor() converts each PIL image into a float tensor of
# shape (C, H, W) with pixel values scaled to [0, 1].
train_data = torchvision.datasets.MNIST(
    root=MNIST_ROOT,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

# Test split. No transform is applied here; downstream code (not visible in
# this excerpt) presumably reads the raw image tensor directly — confirm.
# `download` is passed explicitly so this split no longer depends on the
# training-split construction above having already fetched the test files.
test_data = torchvision.datasets.MNIST(
    root=MNIST_ROOT,
    train=False,
    download=DOWNLOAD_MNIST,
)
# 批处理 (batching): wrap the training set in a DataLoader that yields mini-batches of BATCH_SIZE samples
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=