torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法
查看源码
初始化部分:
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
属性解释:
_parameters:字典,保存用户直接设置的 Parameter
_modules:子 module,即子类构造函数中的内容
_buffers:缓存
_backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
training:判断值来决定前向传播策略
方法定义:
def forward(self, *input):
raise NotImplementedError
没有实际内容,用于被子类的 forward() 方法覆盖
且 forward 方法在 __call__ 方法中被调用:
def __call__(self,