import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
batch_size = 50
# MNIST dataset
train_dataset = dsets.MNIST(root='../../data_sets/mnist', # 选择数据的根目录
train=True, # 选择训练集
transform=transforms.ToTensor(), # 转换成tensor变量
download=False) # 不从网络上download图片
test_dataset = dsets.MNIST(root='../../data_sets/mnist', # 选择数据的根目录
train=False, # 选择训练集
transform=transforms.ToTensor(), # 转换成tensor变量
download=False) # 不从网络上download图片
# 加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,