背景：
在 PyTorch 中使用 MNIST 数据集进行可视化，代码如下：
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
# Part 1: obtain the datasets. torchvision.datasets provides a ready-made
# MNIST dataset class; download=True fetches the files into root on first use.

# One preprocessing pipeline shared by BOTH splits. Previously only the
# training split was normalized, so test inputs had a different value range
# than what a model trained on this data would expect.
#   - ToTensor: PIL image -> float tensor with pixel values scaled to [0, 1].
#   - Normalize(mean=0.5, std=0.5): maps [0, 1] to [-1, 1].
#   - Resize((28, 28)): kept from the original pipeline; should be a no-op
#     for standard MNIST digits (28x28) — presumably a guard against
#     differently sized inputs.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
    transforms.Resize((28, 28)),
])

# Training split.
mnist_train_dataset = datasets.MNIST(
    root="./data/",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Test split: uses the same transform as training (this also fixes the
# unbalanced bracket in the original Compose call, which was a SyntaxError).
mnist_test_dataset = datasets.MNIST(
    root="./data/",
    train=False,
    download=True,
    transform=mnist_transform,
)
# part 2: 数据装载, dataloader
data_loader_train = torch.