Dive into Deep Learning (Chapter 4: Deep Learning Computation)

This chapter covers how to build deep learning models, including subclassing the Module class and using Sequential, ModuleList, and ModuleDict. It also explains how to access, initialize, and share model parameters, how to implement custom layers, and finally how to save and load models.

4.1 Model Construction

  • This chapter constructs models by subclassing the Module class, which makes model construction more flexible.

4.1.1 Constructing a Model by Subclassing the Module Class

  • The Module class, provided in the nn package, is a model-construction class and the base class of all neural network modules.
  • The MLP class defined here overrides the Module class's __init__ and forward functions, which create the model parameters and define the forward computation (forward propagation), respectively.
import torch
from torch import nn


class MLP(nn.Module):
    # Declare layers with model parameters; here, two fully connected layers
    def __init__(self, **kwargs):
        # Call the constructor of the parent class Module to perform the
        # necessary initialization, so that other arguments can still be
        # specified when constructing an instance
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256) # linear transformation from the input to the hidden layer
        self.act = nn.ReLU()              # followed by a ReLU activation
        self.output = nn.Linear(256, 10)  # and a final linear transformation to produce the output

    # Define the model's forward computation: how to compute the output from the input x
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

The MLP class above does not need to define a backward function: through automatic differentiation, the system automatically generates the backward function required for backpropagation.
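For instance, a minimal sketch of autograd in action (the sum-based loss here is a made-up stand-in for a real loss function):

net = MLP()
X = torch.rand(2, 784)
loss = net(X).sum()                  # any scalar works for this illustration
loss.backward()                      # autograd generates and runs the backward pass
print(net.hidden.weight.grad.shape)  # torch.Size([256, 784])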

We can instantiate the MLP class to obtain the model variable net. The code below initializes net and passes the input data X through it for one forward computation. Here, net(X) calls the __call__ function that MLP inherits from the Module class, and __call__ in turn calls the forward function defined by the MLP class to complete the forward computation.

X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)
MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)

tensor([[ 0.0840, -0.0177,  0.0372,  0.2128, -0.1802,  0.2334,  0.0339, -0.1636,
          0.0673,  0.0214],
        [ 0.1027,  0.1027, -0.0777,  0.1907, -0.1190,  0.1237, -0.0201, -0.1508,
          0.1732,  0.0766]], grad_fn=<AddmmBackward>)

Note that the Module class is not named Layer or Model, because it is a freely composable building block. A subclass of it can be a layer (such as the Linear class provided by PyTorch), a whole model (such as the MLP class defined here), or part of a model. The two examples below demonstrate this flexibility.
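As a small illustration of this composability (a sketch; the NestedNet name and layer sizes are made up here), a Module subclass can use another Module subclass, such as the MLP above, as one of its parts:

class NestedNet(nn.Module):
    def __init__(self):
        super(NestedNet, self).__init__()
        self.mlp = MLP()             # a whole model used as a submodule
        self.out = nn.Linear(10, 2)  # an extra layer stacked on the MLP's 10-dim output

    def forward(self, x):
        return self.out(self.mlp(x))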

4.1.2 Subclasses of Module

  • The Module class is a general-purpose building block. PyTorch also provides classes that inherit from Module and make model construction more convenient: Sequential, ModuleList, and ModuleDict.

4.1.2.1 The Sequential Class

  • When the model's forward computation is simply a chain of the layers' computations, the Sequential class offers a simpler way to define the model.
  • It can accept an ordered dictionary (OrderedDict) of submodules, or a sequence of modules, as arguments; it adds the Module instances one by one, and the model's computation runs in the order in which they were added, as sketched below.
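For instance, a minimal sketch of the two construction styles of nn.Sequential (the layer sizes follow the MLP above):

from collections import OrderedDict

# a sequence of modules: children are automatically named "0", "1", "2"
net1 = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# an OrderedDict of submodules: children keep the given names
net2 = nn.Sequential(OrderedDict([
    ('hidden', nn.Linear(784, 256)),
    ('act', nn.ReLU()),
    ('output', nn.Linear(256, 10)),
]))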

To understand more clearly how the Sequential class works, we now implement a MySequential class with the same functionality.

from collections import OrderedDict

class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        # a single OrderedDict argument: add each (name, module) pair
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                # add_module registers module in self._modules (an OrderedDict)
                self.add_module(key, module)
        else:  # a sequence of individual modules was passed in
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        # self._modules is an OrderedDict, so members are traversed in the
        # order in which they were added
        for module in self._modules.values():
            input = module(input)
        return input

We use the MySequential class to implement the MLP described earlier and perform one forward computation with the randomly initialized model.

net = MySequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
print(net)
net(X)
MySequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

The forward pass again returns a tensor of shape (2, 10), analogous to the MLP output above.
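As a quick sanity check, one can verify that MySequential matches the built-in nn.Sequential when both wrap the very same module objects (a minimal sketch; the variable names are made up), since the shared modules carry the same parameters:

layers = [nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)]
mine = MySequential(*layers)
builtin = nn.Sequential(*layers)         # same module objects, hence same parameters
print(torch.equal(mine(X), builtin(X)))  # True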