torch.nn.Module
Module类是用户使用torch来自定义网络模型的基础,Module的设计要求包括低耦合性,高模块化等等。一般来说,计算图上所有的子图都可以是Module的子类,包括卷积,激活函数,损失函数节点以及相邻节点组成的集合等等,注意这里的关键词是“节点”,Module族类在计算图中主要起到搭建结构的作用,而不涉及运算逻辑的具体实现。
要注意的是,Module类对象的children所指向的其他Module类对象,并不等同于计算图中的子节点。如果我们展开Module网络,得到的一般是树形结构而非DAG,Module网络需要经过其他工作才能转化为计算图。
源代码分析
成员分析
首先直接从前端入手,找到torch/nn/module目录,可以看到这个目录下主要存放Module及其子类的定义,如。我们首先找到module.py内Module的定义
阅读__init__ 函数,可以看到Module基类的主要私有成员,其中包括
指向本Module内带梯度的可学习参数的parameter
指向本Module内不需要学习的模型状态参数的buffer
其他临时参数
前向与反向过程的hook函数,这些函数在运行backward与forward时允许自定义其它额外工作
state_dict相关函数,state_dict保存了模型的状态,是模型写入磁盘与加载的主要方式
modules指向该模块内部的所有子模