# example
# 定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
# hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字
# fea_in当前forward 的输入,fea_out当前forward的输出
def hook(module, fea_in, fea_out): # hook,自己定义钩子函数的名字
print("hooker working")
module_name.append(module.__class__)
features_in_hook.append(fea_in)
features_out_hook.append(fea_out)
return None
# children()与modules()都是返回网络模型里的组成元素,
# 但是children()返回的是最外层的元素,modules()返回的是所有的元素,包括不同级别的子元素。
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
if not isinstance(child, nn.ReLU6):
# register_forward_hook, pytorch 提供钩子注册函数
child.register_forward_hook(hook=hook)