模型搭建
我们这次使用LeNet模型,LeNet是一个经典的卷积神经网络(Convolutional Neural Network, CNN)架构,最初由Yann LeCun等人在1998年提出,用于手写数字识别任务
创建一个文件model.py。实现以下代码。
源码
# 导入PyTorch库
import torch
# 从PyTorch库中导入神经网络模块
from torch import nn
# 从torchsummary库中导入summary函数,用于打印模型的结构和参数数量
from torchsummary import summary
# 定义LeNet类,它继承自nn.Module,是一个神经网络模型
class LeNet(nn.Module):
# 初始化函数,定义模型的层次结构
def __init__(self):
# 调用父类的初始化函数
super().__init__()
# 第一个卷积层,输入通道为1,输出通道为6,卷积核大小为5x5,padding为2
self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
# Sigmoid激活函数
self.sig = nn.Sigmoid()
# 第一个平均池化层,池化窗口为2x2,步长为2
self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
# 第二个卷积层,输入通道为6,输出通道为16,卷积核大小为5x5
self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
# 第二个平均池化层,池化窗口为2x2,步长为2
self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
# Flatten层,用于将多维的输入一维化,以便输入到全连接层
self.flatten = nn.Flatten()
# 第一个全连接层,输入特征数为400,输出特征数为120
self.f5 = nn.Linear(400, 120)
# 第二个全连接层,输入特征数为120,输出特征数为84
self.f6 = nn.Linear(120, 84)
# 第三个全连接层,输入特征数为84,输出特征数为10(通常对应分类任务中的类别数)
self.f7 = nn.Linear(84, 10)
# 前向传播函数,定义数据通过网络的方式
def forward(self, x):
x = self.sig(self.c1(x)) # 通过第一个卷积层和Si