LeNet-5 模型训练
下载数据集
数据集可以提前下载到本地，也可以在运行时在线下载（即设置 download=True）。
# Build the MNIST training and test datasets, rooted in the current directory.
# `torchvision` is imported later in this file (after these lines), so it is
# imported here as well; otherwise running top-to-bottom raises NameError.
import torchvision

# NOTE(review): `transform` is not defined anywhere before this point; it must
# be created before these lines run (presumably a torchvision.transforms
# pipeline that converts images to tensors). The model below is annotated for
# 1x32x32 inputs, so the pipeline may also need to resize or pad the images.
# TODO confirm against the full script.
train_data = torchvision.datasets.MNIST(
    root='./',
    download=True,  # fetch online if not already present under root
    train=True,  # training split
    transform=transform,
)
test_data = torchvision.datasets.MNIST(
    root='./',
    download=True,
    train=False,  # held-out test split
    transform=transform,
)
定义并训练模型
import torch
import torchvision
class Lenet5(torch.nn.Module):
def __init__(self):
super(Lenet5, self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5), # 1*32*32 # 6*28*28
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2, # 6*14*14
stride=2),
torch.nn.Conv2d(in_channels=6,
out_channels=16,
kernel_size=5), # 16 *10*10
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2,
stride=2), # 16*5*5
to