hook函数是勾子函数,用于在不改变原始模型结构的情况下,注入一些新的代码用于调试和检验模型,常见的用法有保留非叶子结点的梯度数据(Pytorch的非叶子节点的梯度数据在计算完毕之后就会被删除,访问的时候会显示为
None
),又或者查看模型的层与层之间的数据传递情况(数据维度、数据大小等),抑或是在不修改原始模型代码的基础上可视化各个卷积特征图。
Pytorch提供了四种hook函数
- torch.tensor.register_hook(hooc_func)
- torch.nn.Module.register_forward_hook(hook_func)
- torch.nn.Module.register_forward_pre_hook(hook_func)
- torch.nn.Module.register_backward_hook
1. torch.tensor.register_hook(hooc_func)
解释:注册一个反向传播hook函数,其函数签名如下
def hook(grad):
...
输入参数为张量的梯度,实现的hook函数可以在此修改梯度数据(原地修改或者通过返回值返回),或者在此将梯度数据保存、裁剪等。
示例 1
# leaf node data
x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
# intermediate variable
z = x + y
# output
o = torch.dot(w, z)
# backward to calculate gradient
o.backward()
# print gradient infomation
print('x.grad:', x.grad) # tensor([1., 2., 3., 4.])
print('y.grad:', y.grad) # tensor([1., 2., 3., 4.])
print('w.grad:', w.grad) # tensor([ 4., 6., 8., 10.])
print('z.grad:', z.grad) # None
print('o.grad:', o.grad) # None
输出:
x.grad: tensor([1., 2., 3., 4.])
y.grad: tensor([1., 2., 3., 4.])
w.grad: tensor([ 4., 6., 8., 10.])
z.grad: None
o.grad: None
可以看到代码中的非叶子节点z, o
的梯度信息(grad)在计算之后立即被释放,因此都等于None
,如果需要显式地声明需要保留非叶子节点的grad,需要使用retain_grad
方法,如下例:
import torch
a = torch.ones(5)
a.requires_grad = True
b = 2*a
b.retain_grad() # 让非叶子节点b的梯度保持
c = b.mean()
c.backward()
print(f'a.grad = {
a.grad}\nb.grad = {
b.grad}')
输出:
a.grad = tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
b.grad = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
retain_grad()
方法会增加显存的占用,我们可以使用hook
获取梯度信息而不需要显式地使用retain_grad()
强制系统保存梯度信息,如下例:
import torch
a = torch.ones(5).requires_grad_()
b = 2 * a
a.register_hook(lambda x:print(f'a.grad = {
x}'))
b.register_hook(lambda x: print(f'b.grad = {
x}'))
c = b.mean()
print('begin backward'.center(30, '-'))
c.backward()
print('end backward'.center(30, '-'))
输出:
--------begin backward--------
b.grad = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
a.grad = tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
---------end backward---------
上述例子中我们使用hook
对tensor
的grad
进行访问,没有使用retain_grad
对信息进行保存。输出结果表明,hook
执行的时间是在backward
之间,从后往前依次执行,首先输出b
的grad
,然后输出a
的grad
,最后结束backward
过程。
上述过程都没有对梯度信息进行改变,其实,如果hook
函数的有返回值或者将输入参数grad
原地进行修改的话,那么之后的梯度信息都会被改变,这一机制简直就是为梯度裁剪量身定制的。
如下例:
import torch
def hook(grad):
torch.clamp_(grad, min=0.5, max=0.2)
print(grad)
a = torch.ones(5).requires_grad_()
b = 2 * a
a.register_hook(hook)
b.register_hook(hook)
c = b.mean()
print('begin backward'.center(30, '-'))
c.backward()
print('end backward'.center(30, '-'))
输出:
--------begin backward--------
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
---------end backward---------
对比上一例可以发现a
的梯度从0.4
被裁剪到了0.2
,这里使用的clamp_
是直接原地修改,所以不需要返回值。
也可将上述例子中的hook
更改为有返回值的函数,效果相同。
部分例子参考:https://zhuanlan.zhihu.com/p/662760483
2. torch.nn.Module.register_forward_hook(hook_func)
除了register_hook
是对tensor
操作的hook
之外,其他的hook
都是对module
进行操作的,这里的module
包括各种layer
,例如:Conv2d
, Linear
等
register_forward_hook
在执行module
的forward
函数之后执行,其函数签名为
def hook(module, inputs, outpus):
pass
注意:这里的
module
是当前被注册的module
,inputs
是执行forward
之前的inputs
,而outputs
则是执行forward
之后的outputs
,这么设计可能是为了方便读取执行之前的intputs
。
如下例所示:
import torch