目录
一、nn.Module概述
nn.Module是PyTorch中所有神经网络模块的基类。我们自定义的模型以及PyTorch内置的模型(如nn.Linear, nn.Conv2d等)都是继承自该类。它提供了管理网络组件(层、参数等)的便捷方式,并且支持模型的保存、加载、在GPU上运行等功能。
二、nn.Module的特点
-
封装性:nn.Module可以将一系列操作封装成一个模块,模块中可以包含子模块(即其他nn.Module实例)、参数(nn.Parameter)和其他属性。
-
参数管理 :nn.Module会自动跟踪其内部的nn.Parameter对象,并且可以通过parameters()或named_parameters()方法访问它们,方便优化器进行更新。
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20) # 自动注册参数
self.linear2 = nn.Linear(20, 5) # 自动注册参数
def forward(self, x):
return self.linear2(self.linear1(x))
# 参数自动收集
model = SimpleNet()
print("参数数量:", sum(p.numel() for p in model.parameters()))
print("可训练参数数量:", sum(p.numel() for p in model.parameters() if p.requires_grad))
- 模块化设计:支持将复杂的网络分解成多个子模块,提高代码的可读性和复用性。
nn.Module遵循面向对象的设计原则,允许将神经网络分解为可重用的组件。这种模块化设计带来以下优势:
- 代码复用:可以像搭积木一样构建复杂网络
- 层次清晰:网络结构以树状形式组织,便于理解和调试
- 参数隔离:每个模块管理自己的参数,避免命名冲突
- 设备无关性:可以使用to(device)方法将模型的所有参数和缓冲区移动到指定的设备(CPU或GPU)上。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device) # 模型、参数全部移动到指定设备
- 保存和加载:提供了state_dict()和load_state_dict()方法,方便保存和加载模型状态。
# 保存模型
torch.save(model.state_dict(), "model.pth")
# 加载模型
model.load_state_dict(torch.load("model.pth"))
- 序列化支持
# 保存模型
torch.save(model.state_dict(), "model.pth")
# 加载模型
model.load_state_dict(torch.load("model.pth"))
- 钩子函数:允许注册前向传播和反向传播的钩子函数,用于调试和扩展功能。
class ModelWithHooks(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(10, 5)
# 注册前向传播钩子
self.layer.register_forward_hook(
lambda module, input, output: print(f"Forward: {
output.shape}")
)
# 注册反向传播钩子
self.layer.register_full_backward_hook(
lambda module, grad_input, grad_output: print(f"Backward hook triggered")
)
三、 nn.Module的使用方式
- 自定义模型
我们通过继承nn.Module类,并实现__init__和forward方法来定义自己的模型。
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)

最低0.47元/天 解锁文章
4082

被折叠的 条评论
为什么被折叠?



