4.1 Model Construction
- This chapter introduces model construction based on the Module class: it makes model construction more flexible
4.1.1 Constructing a Model by Inheriting the Module Class
- Module, provided in the nn package, is a model-construction class and the base class of all neural network modules
- The MLP class defined here overrides Module's __init__ and forward functions, which are used to create the model parameters and to define the forward computation (forward propagation), respectively
import torch
from torch import nn

class MLP(nn.Module):
    # Declare layers with model parameters; here, two fully connected layers
    def __init__(self, **kwargs):
        # Call the constructor of MLP's parent class Module to perform the
        # necessary initialization, so that other arguments can still be
        # specified when constructing an instance
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)  # linear transform of the input into the hidden layer
        self.act = nn.ReLU()               # followed by a ReLU activation
        self.output = nn.Linear(256, 10)   # a final linear transform produces the output

    # Define the forward computation, i.e. how to compute the required
    # model output from the input x
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
The MLP class above needs no backward function: the system automatically generates the backward function required for backpropagation via autograd.
We can instantiate the MLP class to obtain the model variable net. The code below initializes net and passes the input data x through one forward computation. Here, net(x) calls the __call__ function that MLP inherits from the Module class, and __call__ in turn calls the forward function defined in the MLP class to complete the forward computation.
x = torch.rand(2, 784)
net = MLP()
print(net)
net(x)
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[ 0.0840, -0.0177, 0.0372, 0.2128, -0.1802, 0.2334, 0.0339, -0.1636,
0.0673, 0.0214],
[ 0.1027, 0.1027, -0.0777, 0.1907, -0.1190, 0.1237, -0.0201, -0.1508,
0.1732, 0.0766]], grad_fn=<AddmmBackward>)
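As noted above, MLP defines no backward function. A minimal sketch (reusing net and x from the code above) showing that autograd supplies the gradients:
loss = net(x).sum()                  # reduce the output to a scalar
loss.backward()                      # backward pass generated automatically by autograd
print(net.hidden.weight.grad.shape)  # torch.Size([256, 784])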
Note that the Module class is not given a name like Layer or Model, because it is a component that can be freely composed: a subclass of Module can be a layer (such as the Linear class provided by PyTorch), a whole model (such as the MLP class defined here), or a part of a model. The two subsections below demonstrate this flexibility.
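To illustrate "a part of a model" first, here is a minimal sketch (a hypothetical FancyNet, assuming the MLP class defined above) in which one Module is nested inside another as a submodule:
class FancyNet(nn.Module):
    # a hypothetical model that reuses the MLP above as one of its parts
    def __init__(self):
        super(FancyNet, self).__init__()
        self.mlp = MLP()              # an entire model used as a submodule
        self.head = nn.Linear(10, 2)  # an extra layer stacked on top
    def forward(self, x):
        return self.head(self.mlp(x))

FancyNet()(torch.rand(2, 784)).shape  # torch.Size([2, 2])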
4.1.2 Subclasses of Module
- Module is a general-purpose component; PyTorch also provides classes that inherit from Module and make model construction convenient: Sequential, ModuleList and ModuleDict
4.1.2.1 The Sequential Class
- When the model's forward computation is simply a serial chain through each layer, the Sequential class offers a simpler way to define the model
- It accepts an ordered dictionary (OrderedDict) of submodules, or a sequence of modules, as arguments, adding the Module instances one by one; the model's computation then runs in the order they were added (see the sketch after this list)
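A minimal sketch of both construction styles (the layer names here are illustrative):
from collections import OrderedDict
# passing an OrderedDict: each submodule gets the given name
net1 = nn.Sequential(OrderedDict([
    ('hidden', nn.Linear(784, 256)),
    ('act', nn.ReLU()),
    ('output', nn.Linear(256, 10)),
]))
# passing modules directly: submodules are named '0', '1', '2', ...
net2 = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))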
Below we implement a MySequential class with the same functionality as the Sequential class, to see more clearly how the Sequential class works.
from collections import OrderedDict

class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        # a single argument that is an OrderedDict: add each named module
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                # add_module adds the module to self._modules (an OrderedDict)
                self.add_module(key, module)
        else:  # otherwise the arguments are individual modules
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        # self._modules is an OrderedDict, so iteration is guaranteed to
        # follow the order in which members were added
        for module in self._modules.values():
            input = module(input)
        return input
We use the MySequential class to implement the MLP described earlier and run one forward computation with the randomly initialized model.
net = MySequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
print(net)
net(x)
MySequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
tensor([...], grad_fn=<AddmmBackward>)  # output of shape (2, 10), as with the MLP above