import torch
from torch import nn
from torch.utils import data
import torchvision
from torchvision import transforms
'''
这是我自己参考了李沐的代码
和
https://blog.youkuaiyun.com/alionsss/article/details/129973035
从而复现的LeNet
'''
# Hyperparameters
batch_size = 256  # samples per mini-batch; read by load()
epoch_num = 20    # presumably the number of training epochs (used outside this view)
lr = 0.03         # presumably the optimizer learning rate (used outside this view)

# Use the first CUDA device when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load(*, root="data", num_workers=8, download=True):
    """Build the Fashion-MNIST training and test DataLoaders.

    Every image is converted with ``transforms.ToTensor()``. The batch size
    comes from the module-level ``batch_size`` hyperparameter.

    Args:
        root: Local directory where the dataset is stored.
        num_workers: Worker process count passed to each ``DataLoader``.
        download: Forwarded to torchvision's ``FashionMNIST`` ``download``
            flag.

    Returns:
        A ``(train_iter, test_iter)`` tuple of DataLoaders. Only the training
        loader shuffles its samples.
    """
    # ToTensor holds no per-dataset state, so one instance serves both splits.
    to_tensor = transforms.ToTensor()
    fashion_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=to_tensor, download=download)
    fashion_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=to_tensor, download=download)

    # Wrap the datasets in mini-batch iterators; evaluation order stays fixed.
    train_iter = data.DataLoader(
        fashion_train, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = data.DataLoader(
        fashion_test, batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter
def netdefine():
# 网络定义
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
# 跟着李沐学AI代码复现--01LeNet, 02AlexNet
# 于 2024-07-13 16:05:33 首次发布