import torch
from matplotlib import pyplot as plt
from torch import nn, optim
# from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from matplotlib.ticker import MaxNLocator
# 超参数
batch_size = 256 # 批大小
learning_rate = 0.0001 # 学习率
epochs = 20 # 迭代次数
channels = 1 # 图像通道大小
# 数据集下载和预处理
transform = transforms.Compose([transforms.ToTensor(), # 将图片转换成PyTorch中处理的对象Tensor,并且进行标准化0-1
transforms.Normalize([0.5], [0.5])]) # 归一化处理
path = './data/' # 数据集下载后保存的目录
# 下载训练集和测试集
trainData = datasets.MNIST(path, train=True, transform=transform, download=True)
testData = datasets.MNIST(path, train=False, transform=transform)
# 处理成data loader
trainDataLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True) # 批量读取并打乱
testDataLoader = torch.utils.data.DataLoader(dataset=testData, batch_size=batch_size)
# 开始构建cnn模型
class cnn(torch.nn.Module):
def __init__(self):
super(cnn, self).__init__()
self.model = torch.nn.Sequential(
# The size of the picture is 28*28
torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Ma
PyTorch入门——实现MNIST分类
于 2022-04-28 09:59:04 首次发布