跟着李沐学AI代码复现--01LeNet,02ALexNet

import torch
from torch import nn
from torch.utils import data

import torchvision
from torchvision import transforms

'''
这是我自己参考了李沐的代码
和
https://blog.youkuaiyun.com/alionsss/article/details/129973035
从而复现的LeNet
'''

# 超参数定义
batch_size = 256
epoch_num = 20
lr = 0.03
if torch.cuda.device_count() >= 1:
    device = torch.device(f'cuda:0')
else:
    device = torch.device('cpu')


def load():
    # 加载数据集
    # 这里的root指的是本地目录。
    mnist_train = torchvision.datasets.FashionMNIST(
        root="data", train=True, transform=transforms.ToTensor(), download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="data", train=False, transform=transforms.ToTensor(), download=True)
    # 将数据集转为特定格式 
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=8)
    test_iter = data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=8)

    return train_iter, test_iter


def netdefine():
    # 网络定义
    net = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值