1. 继承 torch.nn.Module
类(推荐方法)
最常见和推荐的方式是通过继承 torch.nn.Module
类来创建一个自定义的神经网络模型。在这种方式下,你需要定义 __init__()
方法来初始化网络层,并在 forward()
方法中定义前向传播逻辑。
示例:一个简单的全连接神经网络
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
# 定义网络层
self.fc1 = nn.Linear(784, 128) # 输入层:28x28 图像展平为 784
self.fc2 = nn.Linear(128, 64) # 隐藏层
self.fc3 = nn.Linear(64, 10) # 输出层:10 类分类
# 激活函数
self.relu = nn.ReLU()
def forward(self, x):
# 前向传播逻辑
x = self.relu(self.fc1(x)) # 输入 -> 第一层 -> 激活
x = self.relu(self.fc2(x)) # 第二层 -> 激活
x = self.fc3(x) # 输出层
return x
# 创建模型实例
model = SimpleNN()
print(model)
解释:
__init__()
:在这个方法中定义了神经网络的层(如nn.Linear
),并且可以定义激活函数(如nn.ReLU()
)。forward()