pytorh之lenet的实现(CIFAR10数据集)
import torch as t
import matplotlib.pyplot as plt
import torchvision as tv
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor转成Image,方便可视化
第一次运行程序torchvision会自动下载CIFAR-10数据集,
大约100M,需花费一定的时间,
如果已经下载有CIFAR-10,可通过root参数指定
#定义对数据的预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
])
# 训练集
trainset = tv.datasets.CIFAR10(
root='./data/cifar/',
train=True,
download=True,
transform=transform)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
# 测试集
testset = tv.datasets.CIFAR10(
'./data/cifar/',