应用的对象应是nn.Module类型或者它的继承类。
可以有children。
比如
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
net = Model()
这里的net就是一个Module对象,它包含了children: conv1, conv2
fn用来给childer初始化weight,
比如fn可以这样定义:
def weight_init(m):
""" Weight initialization function. """
# Conv2D
if isinstance(m, nn.Conv2d):
init.xavier_normal_(m.weight.data)
if m.bias is not None:
init.normal_(m.bias.data)
else:
pass
用的时候Applies fn recursively to every submodule (as returned by .children()) as well as self.
可看到官方源码如下,先对children应用fn, 再对self应用,最后返回self
def apply(self: T, fn: Callable[['Module'], None]) -> T:
for module in self.children():
module.apply(fn)
fn(self)
return self
应用例:
net.apply(weight_init)