神经网络由执行数据操作的层/模块组成。torch.nn
命名空间提供了构建神经网络所需的所有基础模块。PyTorch中的每个模块都是nn.Module
的子类。神经网络本身也是一个模块,由其他模块(层)组成。这样的嵌套结构可以轻松地构建和管理复杂的网络架构。
在接下来的部分中,我们将构建一个神经网络,用于对FashionMNIST数据集中的图像进行分类。
代码引用头文件如下:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
获取训练设备
我们希望能够在 GPU 或 MPS 等硬件加速器上训练模型(如果可用)。请最好安装GPU版本的pytorch,因为GPU可将一个长达一小时的训练缩短至一分钟以内。以下代码用于检查 torch.cuda
或 torch.backends.mps
是否可用,若不可用将使用 CPU。
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print(f"Using {device} device")
定义类
我们通过继承 nn.Module
来定义神经网络,并在 __init__
方法中初始化神经网络的层。每个 nn.Module
子类都在 forward
方法中实现对输入数据的操作。
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__() # 调用父类 nn.Module 的构造函数
self.flatten = nn.Flatten() # 定义一个将输入展平(从二维28x28转为一维784)的层
# 定义一个顺序容器,包含一系列线性层和ReLU激活函数
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512), # 第一层:将784维的输入映射到512维
nn.ReLU(), # 激活函数:ReLU
nn.Linear(512, 512), # 第二层:将512维的输入映射到512维
nn.ReLU(), # 激活函数:ReLU
nn.Linear(512, 10), # 输出层:将512维的输入映射到10维(对应分类任务的10个类别)
)
def forward(self, x):
x = self.flatten(x) # 将输入 x 展平为一维
logits = self.linear_relu_stack(x) # 将展平后的输入传入线性+ReLU堆栈,得到logits
return logits # 返回分类结果(未经过Softmax处理的logits)
接下来从面向对象的角度讲解这段代码:
类与继承
在 Python 面向对象编程中,类是创建对象的蓝图。我们可以通过定义类来封装数据和行为。在这段代码中:
NeuralNetwork
是一个子类,它继承了nn.Module
(PyTorch 中所有神经网络模型的基类)。继承意味着NeuralNetwork
可以使用nn.Module
的方法和属性。- 通过继承
nn.Module
,NeuralNetwork
获得了 PyTorch 中的一些基础功能,比如模块的初始化、参数管理、模型保存和加载等。调用super().__init__()
就是为了初始化父类nn.Module
的功能。
总结:继承帮助 NeuralNetwork
类复用了 nn.Module
的功能,使得我们可以专注于实现特定的神经网络结构。
封装
封装 是面向对象的一个核心思想,它允许类将数据(属性)和行为(方法)包装在一起,并对外部隐藏实现细节。__init__
和 forward
方法都是封装的一部分。
__init__
方法:构造函数,用于初始化对象。在这里,我们在构造函数中定义了模型的结构(例如self.flatten</