MXNet的训练基础脚本:base_module.py

本文介绍了MXNet中base_module.py脚本的重要性,特别是BaseModule类。BaseModule类为模型实现提供了框架,并且是mxnet.module.Module类的基类。重点讨论了fit()方法以及其他未实现但引发异常的方法,这些方法需要在子类中具体实现。了解这些基础将有助于深入理解MXNet的模型训练过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

写在前面:在MXNet中有一个很重要的脚本:base_module.py,这个脚本中的BaseModule类定义了和模型实现相关的框架。另外还有一个脚本module.py会在另外一篇博客中讲,这个类继承BaseModule类,并进行了具体的实现,比如forward方法,update_metric方法等。

在创建模型的时候用的是mx.mod.Module这个类,比如:model = mx.mod.Module(symbol = sym),具体这个Module类的内容是什么?其实mx.mod.Module准确地写应该是mxnet.module.Module(),这个类的路径是~/mxnet/python/mxnet/module/module.py,但是在实际调用的时候是调用python库中的module.py,路径是:/usr/local/lib/python2.7/dist-packages/mxnet-0.10.1-py2.7.egg/mxnet/module/module.py,这个mxnet-0.10.1-py2.7.egg就是在编译python,即运行setup.py时候生成的。module.py这个脚本中定义了Module类,而这个类继承另一个类BaseModule,BaseModule这个类是在base_module.py脚本中定义的,该脚本和module.py脚本在同一目录下。因为fit()方法继承自BaseModule类并且在BaseModule类中实现,Module类中并没有做修改,所以在BaseModule类中主要看fit()方法。注意到在BaseModule类中其他方法最后都有一句raise NotImplementedError(),这句话是引发未实现的异常的,也就是说如果你继承了BaseModule类却没有实现一些方法(带raise这句话),那么在调用这个方法的时候就会引发异常。

base_module.py代码的git地址:
https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/base_module.py

base_module.py中最重要的就是BaseModule类,接下来截取这个类中重要的几个函数来介绍,大部分函数只是在这里定义,具体实现一般由别的类通过继承该类进行实现:

class BaseModule(object):
"""The base class of a module.

    A module represents a computation component. One can think of module as a computation machine.
    A module can execute forward and backward passes and update parameters in a model.
    We aim to make the APIs easy to use, especially in the case when we need to use the imperative
    API to work with multiple modules (e.g. stochastic depth network).

A module has several states:

    - Initial state: Memory is not allocated yet, so the module is not ready for computation yet.
    - Binded: Shapes for inputs, outputs, and parameters are all known, memory has been allocated,
      and the module is ready for computation.
    - Parameters are initialized: For modules with parameters, doing computation before
      initializing the parameters might result in undefined outputs.
    - Optimizer is installed: An optimizer can be installed to a module. After this, the parameters
      of the module can be updated according to the optimizer after gradients are computed
      (forward-backward).

In order for a module to interact with others, it must be able to report the
    following information in its initial state (before binding):

    - `data_names`: list of type string indicating the names of the required input data.
    - `output_names`: list of type string indicating the names of the required outputs.

    After binding, a module should be able to report the following richer information:

    - state information
        - `binded`: `bool`, indicates whether the memory buffers needed for computation
          have been allocated.
        - `for_training`: whether the module is bound for training.
        - `params_initialized`: `bool`, indicates whether the parameters of this module
          have been initialized.
        - `optimizer_initialized`: `bool`, indicates whether an optimizer is defined
          and initialized.
        - `inputs_need_grad`: `bool`, indicates whether gradients with respect to the
          input data a
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值