[Pytorch] torch.nn.Module的hook机制

本文深入探讨PyTorch中Module的hook机制,包括前向和后向hook的注册与执行流程,以及它们在调试和性能分析中的应用。特别关注了hook在不同层级(全局与模块级)的作用及调用顺序。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch\nn\modules\module.py文件里的Module是所有Module的基类。

 

以下3个dict里面放的是作用于所有Module的hook:

_global_backward_hooks = OrderedDict()
_global_forward_pre_hooks = OrderedDict()
_global_forward_hooks = OrderedDict()

原注释:This is global state used for debugging/profiling purposes

通过register_module_forward_pre_hook等3个全局函数进行注册;

 

以下3个dict里面放的是作用于本Module的hook:

self._backward_hooks

self._forward_pre_hooks

self._forward_hooks

通过Module.register_backward_hook等3个Module的成员函数进行注册;

 

调用myModule(input),即调用Module.__call__, 等价于调用_call_impl;里面的执行顺序:

1. 先执行_global_forward_pre_hooks,再执行self._forward_pre_hooks

2. 执行self.forward (执行的是具体子类的forward)

3. 先执行_global_forward_hooks,再执行self._forward_hooks

4. 将上步输出结果var(一个Tensor),对其调用var.grad_fn.register_hook,先把_global_backward_hooks注册进去,再把self._backward_hooks注册进去;

 

grad_fn.register_hook:

grad_fn对应python_variable.cpp里的THPVariable_get_grad_fn;

    Tensor.py里可以看到Tensor是继承自torch._C._TensorBase的;python_variable.cpp里torch._C._TensorBase的THPVariable_properties里面有grad_fn;grad_fn就是THPVariable_get_grad_fn;

THPVariable_get_grad_fn里的self->cdata就是Tensor(C++的);调用Tensor::grad_fn()

Tensor::grad_fn()的返回值是Node类型;

在python_cpp_function.h里,register_hook和THPCppFunction_register_hook是绑定的;

    THPCppFunction_register_hook里面调用的是registerFunctionHook    

在python_function.cpp里,register_hook和THPFunction_register_hook是绑定的;

    THPFunction_register_hook里调用的也是torch::autograd::registerFunctionHook;有一些注释:

        "Legacy autograd function had _register_hook called before the function was "

        "invoked.  This usage pattern is no longer supported: please call _register_hook "

        "AFTER calling your function, or port your code to use non-legacy autograd function, see: "

        "https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd"

registerFunctionHook里对这个Node调用了add_post_hook

Node的hook被谁调用:

function.h里,Node类有add_post_hook和add_pre_hook成员;操作post_hooks_和pre_hooks_私有成员;

post_hooks_和pre_hooks_是通过post_hooks()和pre_hooks()被外界访问的;

engine.cpp里的call_pre_hooks和call_post_hooks分别调用了pre_hooks()和post_hooks();

engine.cpp里,call_function函数先调用call_pre_hooks,再调用具体的fn(...)执行操作,最后调用call_post_hooks;

Engine::thread_main ==> Engine::evaluate_function ==> call_function

tensor.py里,Tensor的register_hook成员函数:

register_hook调用了grad_fn._register_hook_dict

_register_hook_dict:

    也许对应THPCppFunction_register_hook_dict

    也许对应THPFunction__register_hook_dict

以上两个,都调用的是Node的add_pre_hook,而不是add_post_hook,似乎和Tensor.register_hook的注释有些矛盾了?

例子:

a = torch.tensor([2.], requires_grad=True)

b = a * 3

b.register_hook(lambda grad: print(a.grad))

b.backward()

None  (输出)

a.grad

tensor([3.])  (输出)

以上证明了就是backward之前调用的hook;如果是backward之后调用的hook,输出应该是最后哪个;

 

似乎对应torch\csrc\autograd\python_cpp_function.h里的THPCppFunction_register_hook函数(C函数),而注册进去的hook函数似乎是python函数;

在Tensor.cpp里,可以找到Tensor::grad_fn(),其返回值是Node类型对象;因此grad_fn本质是Node(torch\csrc\autograd\function.h);

而Node类里,可以找到add_pre_hook和add_post_hook,因此我认为是可以给backward前后都加上hook的;

 

module.py里,在2中backward hook的注册函数上方,有注释:

        The current implementation will not have the presented behavior
        for complex :class:`Module` that perform many operations.
        In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
        contain the gradients for a subset of the inputs and outputs.
        For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
        directly on a specific input or output to get the required gradients.

各种Loss,也是继承自Module的;所以这些Loss的backward调用,可以适用Module的global的hook;

但是如果用户自己写了个没有继承自Module的loss,那backward上就没有hook了,感觉还得使用Tensor.backward的hook才最保险?

 

疑问:forward的前后都可以设hook,为什么backward只能后面设hook, 前面为什么不设??

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值