Parameters
class torch.nn.Parameter()
Variable的一种,常被用于模块参数(module parameter)。
Parameters 是 Variable 的子类。Paramenters和Modules一起使用的时候会有一些特殊的属性,即:当Paramenters赋值给Module的属性的时候,他会自动的被加到 Module的 参数列表中(即:会出现在 parameters() 迭代器中)。将Varibale赋值给Module属性则不会有这样的影响。 这样做的原因是:我们有时候会需要缓存一些临时的状态(state), 比如:模型中RNN的最后一个隐状态。如果没有Parameter这个类的话,那么这些临时变量也会注册成为模型变量。
Variable 与 Parameter的另一个不同之处在于,Parameter不能被 volatile(即:无法设置volatile=True)而且默认requires_grad=True。Variable默认requires_grad=False。
Containers
class torch.nn.Module
add_module(name, module)
将一个 child module 添加到当前 modle。 被添加的module可以通过 name属性来获取。 例:
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.add_module("conv", nn.Conv2d(10, 20, 4))
#self.conv = nn.Conv2d(10, 20, 4) is the same as above
model = Model()
print(model.conv)
输出:
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
children()
Returns an iterator over immediate children modules. 返回当前模型 子模块的迭代器。
import torch.nn as nn
class Model(nn.Module):
def __init__