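Preface: MXNet has an important script, base_module.py, whose BaseModule class defines the framework for implementing a model. A second script, module.py, will be covered in a separate post; its Module class inherits from BaseModule and supplies the concrete implementations, such as the forward and update_metric methods.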
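A model is created with the mx.mod.Module class, for example: model = mx.mod.Module(symbol = sym). So what exactly is inside this Module class? Written out in full, mx.mod.Module is mxnet.module.Module, and in the source tree it lives at ~/mxnet/python/mxnet/module/module.py. At run time, however, Python loads the module.py from the installed library, at /usr/local/lib/python2.7/dist-packages/mxnet-0.10.1-py2.7.egg/mxnet/module/module.py. The mxnet-0.10.1-py2.7.egg directory is generated when the Python package is installed, that is, when setup.py is run.
module.py defines the Module class, which inherits from another class, BaseModule. BaseModule is defined in base_module.py, which sits in the same directory as module.py. The fit() method is inherited from BaseModule, where it is implemented, and Module does not override it, so fit() is the main thing to look at in BaseModule.
Note that many of the other methods in BaseModule end with raise NotImplementedError(). This statement raises a "not implemented" exception: if you inherit from BaseModule without implementing one of these methods, calling that method raises the exception.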
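The division of labour described above can be sketched in plain Python, without MXNet. The class names BaseTrainer and CountingTrainer are hypothetical stand-ins for BaseModule and Module, and the loop is a deliberate simplification, not MXNet's actual fit() implementation:

```python
class BaseTrainer(object):
    """Hypothetical stand-in for BaseModule: owns the training loop."""

    def fit(self, batches, num_epoch=1):
        # Implemented once in the base class; it only calls the hooks below.
        for _ in range(num_epoch):
            for batch in batches:
                self.forward(batch)
                self.backward()
                self.update()

    def forward(self, batch):
        # Subclasses must override this hook.
        raise NotImplementedError()

    def backward(self):
        raise NotImplementedError()

    def update(self):
        raise NotImplementedError()


class CountingTrainer(BaseTrainer):
    """Hypothetical stand-in for Module: implements the hooks, inherits fit()."""

    def __init__(self):
        self.calls = []

    def forward(self, batch):
        self.calls.append(("forward", batch))

    def backward(self):
        self.calls.append(("backward", None))

    def update(self):
        self.calls.append(("update", None))


trainer = CountingTrainer()
trainer.fit(batches=[1, 2])  # fit() comes from BaseTrainer, unchanged

try:
    BaseTrainer().fit(batches=[1])  # no hooks implemented in the base class
    base_failed = False
except NotImplementedError:
    base_failed = True
```

The subclass never touches fit(), yet fit() drives the subclass's own forward, backward and update. Calling fit() on the bare base class fails at the first unimplemented hook.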
Source of base_module.py on GitHub:
https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/base_module.py
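The link above points at the master branch, which may differ from the copy Python actually imports on your machine. A minimal sketch for finding the installed file, assuming Python 3 (importlib.util is not available in Python 2.7); the helper name locate_module is hypothetical:

```python
import importlib.util


def locate_module(name):
    """Return the file path Python would load for `name`, or None if it is not installed."""
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. `mxnet`) cannot be found.
        return None
    return spec.origin if spec is not None else None


# Prints the installed path of base_module.py, or None if MXNet is absent.
print(locate_module("mxnet.module.base_module"))
```

Note that looking up a dotted name imports its parent package, so this call imports mxnet when it is installed.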
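The most important part of base_module.py is the BaseModule class. The excerpts below cover its key functions. Most of them are only declared here; the concrete implementations are usually provided by other classes that inherit from it: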
class BaseModule(object):
    """The base class of a module.
    A module represents a computation component. One can think of module as a computation machine.
    A module can execute forward and backward passes and update parameters in a model.
    We aim to make the APIs easy to use, especially in the case when we need to use the imperative
    API to work with multiple modules (e.g. stochastic depth network).
    A module has several states:
    - Initial state: Memory is not allocated yet, so the module is not ready for computation yet.
    - Binded: Shapes for inputs, outputs, and parameters are all known, memory has been allocated,
      and the module is ready for computation.
    - Parameters are initialized: For modules with parameters, doing computation before
      initializing the parameters might result in undefined outputs.
    - Optimizer is installed: An optimizer can be installed to a module. After this, the parameters
      of the module can be updated according to the optimizer after gradients are computed
      (forward-backward).
    In order for a module to interact with others, it must be able to report the
    following information in its initial state (before binding):
    - `data_names`: list of type string indicating the names of the required input data.
    - `output_names`: list of type string indicating the names of the required outputs.
    After binding, a module should be able to report the following richer information:
    - state information
        - `binded`: `bool`, indicates whether the memory buffers needed for computation
          have been allocated.
        - `for_training`: whether the module is bound for training.
        - `params_initialized`: `bool`, indicates whether the parameters of this module
          have been initialized.
        - `optimizer_initialized`: `bool`, indicates whether an optimizer is defined
          and initialized.
        - `inputs_need_grad`: `bool`, indicates whether gradients with respect to the
          input data a