nn.Module
是 PyTorch 中一个非常重要的基类,用于构建神经网络模型。它提供了一些基本功能和结构,使得构建、组合和管理模型变得更加方便和高效。具体来说,nn.Module
主要有以下几个用处:
封装模型的组件:通过继承
nn.Module
,可以将神经网络的不同层(如卷积层、全连接层等)封装在一起,便于管理和组织代码。参数管理:
nn.Module
自动注册其子模块(子层)中的所有参数,使得可以方便地获取和更新模型的参数。在调用model.parameters()
时,会返回所有可训练的参数。前向传播:重写
forward
方法定义前向传播的逻辑,使得可以根据输入数据计算输出结果。PyTorch 利用这一方法实现自动微分。模型序列化:
nn.Module
提供了保存(save
)和加载(load
)模型的方法,方便模型的持久化和复用。模块组合:可以方便地将多个模块组合成更复杂的模型,支持嵌套使用子模块,使得构建复杂的深度学习模型变得更加直观。
简单的示例代码如下:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
model = MyModel()
input_data = torch.randn(1, 10)
output = model(input_data)
print(output)
在这个示例中,MyModel
继承了 nn.Module
,定义了一个简单的前馈神经网络。通过这种方式,PyTorch 的模块化设计和管理变得简洁明了。
在这个类中有俩函数:初始化函数__init__(self) 和 前向传播函数forward(self,input)
写个简单的例子(代码实现过程写在了注释中):
import torch
from torch import nn
class Aaax(nn.Module): #Module表示一个类,首字母要大写
def __init__(self):
super().__init__()
def forward(self,input):
output = input+2
return output
demo=Aaax() #会调用super().__init__() ,完成Module的初始化
a=torch.tensor(1)
output=demo(a) #这里会把 a 传入forward中的input中
print(output) #输出tensor(3)