While reading the official PyTorch documentation, I noticed that hooks show up both in the nn.Module section and in the Variable section. This was intriguing, since I had never come across the term while using TensorFlow, so I decided to dig into it.
Hooks on Variable
register_hook(hook)
Registers a backward hook.
The hook will be called every time gradients with respect to the Variable are computed. It should have the following signature:
hook(grad) -> Variable or None
The hook should not modify its argument, but it may optionally return a new gradient that will be used in place of the current one.
This function returns a handle. The handle has a method handle.remove() that removes the hook from the Variable.
Example:
v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
h = v.register_hook(lambda grad: grad * 2)  # double the gradient
v.backward(torch.Tensor([1, 1, 1]))  # the original gradient is computed first, then passed through the hook, yielding a new gradient
print(v.grad.data)
h.remove()  # removes the hook
Output:
 2
 2
 2
[torch.FloatTensor of size 3]
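As a quick check that handle.remove() really detaches the hook, we can zero the gradient and run backward() again after removing it. A small sketch (note that gradients accumulate into v.grad, so we clear it between the two backward calls):

v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
h = v.register_hook(lambda grad: grad * 2)
v.backward(torch.Tensor([1, 1, 1]))
print(v.grad.data)            # 2 2 2 -- the hook doubled the gradient

h.remove()                    # detach the hook
v.grad.data.zero_()           # clear the accumulated gradient first
v.backward(torch.Tensor([1, 1, 1]))
print(v.grad.data)            # 1 1 1 -- the hook no longer runs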
Hooks on nn.Module
register_forward_hook(hook)
Registers a forward hook on the module.
One thing to note: the hook is meant to be registered on a Module in the narrow sense, i.e. a Module that simply wraps a single op, rather than the class we write ourselves by subclassing Module; a subclassed Module that contains other Modules is referred to as a Container below. (In practice this restriction matters mainly for the backward hook discussed later; a forward hook registered on a Container simply fires around the Container's own forward.)
The hook will be called every time forward() computes an output. It should have the following signature:
hook(module, input, output) -> None
The hook should not modify the input or output. This function returns a handle. The handle has a method handle.remove() that removes the hook from the module.
This description may sound a bit abstract, but it all becomes clear once you look at how nn.Module actually uses hooks in its source code.
First, register_forward_hook:
def register_forward_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle

All this method does is register a hook on this module. The first statement is not important here; the key is the second one, which stores the registered hook in the _forward_hooks dict.
Next, look at nn.Module's __call__ method (stripped down to the part we care about):

def __call__(self, *input, **kwargs):
    result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():   # pull out the registered hooks and call them
        hook_result = hook(self, input, result)
    ...
    return result
As you can see, when we call model(x), the following happens under the hood:
1. forward is called to compute the result.
2. The module checks whether any forward hooks are registered; if so, each hook is called with the forward inputs and the result as its arguments, and the hook can then do whatever it likes with them.
With this, the meaning of the hook signature becomes clear, and so does the reason why the hook should not modify the input or output.
A small example:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

def for_hook(module, input, output):
    print(module)
    for val in input:
        print("input val:", val)
    for out_val in output:
        print("output val:", out_val)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return x + 1

model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()
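A common practical use of forward hooks is grabbing intermediate activations without touching the model code. The sketch below is not from the original post; the network layout and the 'fc1' name are made up purely for illustration:

import torch
from torch import nn
from torch.autograd import Variable

net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

activations = {}

def save_activation(name):
    def hook(module, input, output):
        activations[name] = output   # stash the intermediate output, do not modify it
    return hook

h = net[0].register_forward_hook(save_activation('fc1'))  # hook the first Linear only

x = Variable(torch.randn(1, 4))
y = net(x)
print(activations['fc1'].size())   # (1, 8): the output of the first Linear
h.remove()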
register_backward_hook(hook)
Registers a backward hook on the module. At the moment this can only be used on a Module, not on a Container: when a Module's forward involves only a single Function, it is called a Module here; if a Module contains other Modules, it is called a Container.
The hook will be called every time the gradients with respect to the module's inputs are computed. It should have the following signature:
hook(module, grad_input, grad_output) -> Tensor or None
If the module has multiple inputs or outputs, grad_input and grad_output will be tuples.
The hook should not modify its arguments, but it may optionally return a new gradient with respect to the input, which will then be used in place of grad_input in subsequent computations.
This function returns a handle. The handle has a method handle.remove() that removes the hook from the module.
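The Module-vs-Container restriction above is easy to trip over. In the sketch below (not from the original post) the backward hook is registered on an nn.Sequential whose forward runs several Functions; with the legacy semantics, the grad_input the hook receives belongs to the last internal operation rather than to the Sequential's own input, which is exactly why the docs limit this hook to "simple" Modules:

import torch
from torch import nn
from torch.autograd import Variable

def inspect(module, grad_input, grad_output):
    # grad_input / grad_output are tuples; just report what actually arrives
    print('grad_input :', [None if g is None else g.size() for g in grad_input])
    print('grad_output:', [None if g is None else g.size() for g in grad_output])

seq = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))   # a Container: several Functions inside
seq.register_backward_hook(inspect)

x = Variable(torch.randn(1, 3), requires_grad=True)
out = seq(x)
out.backward(torch.ones(1, 1))
# With the legacy hook, the sizes printed for grad_input correspond to the inputs of
# the second Linear, not to x, so returning a replacement gradient here would not
# behave the way the name grad_input suggests.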
Judging from this description, a backward hook lets us post-process gradients after they have been computed. The implementation of register_backward_hook in nn.Module, shown below, is almost identical to register_forward_hook: it simply stores the registered hook in a dict.
def register_backward_hook(self, hook):
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
Let's first look at an example to see what the hook's arguments actually contain:
import torch
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn as nn
import math

def bh(m, gi, go):
    print("Grad Input")
    print(gi)
    print("Grad Output")
    print(go)
    return gi[0] * 0, gi[1] * 0

class Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        if self.bias is None:
            return self._backend.Linear()(input, self.weight)
        else:
            return self._backend.Linear()(input, self.weight, self.bias)

x = Variable(torch.FloatTensor([[1, 2, 3]]), requires_grad=True)
mod = Linear(3, 1, bias=False)
mod.register_backward_hook(bh)  # register the backward hook on the module

out = mod(x)
out.register_hook(lambda grad: 0.1 * grad)  # register a hook on the Variable
out.backward()
print(['*'] * 20)
print("x.grad", x.grad)
print(mod.weight.grad)
Output:
Grad Input
(Variable containing:
1.00000e-02 *
5.1902 -2.3778 -4.4071
[torch.FloatTensor of size 1x3]
, Variable containing:
0.1000 0.2000 0.3000
[torch.FloatTensor of size 1x3]
)
Grad Output
(Variable containing:
0.1000
[torch.FloatTensor of size 1x1]
,)
['*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*']
x.grad Variable containing:
 0 -0 -0
[torch.FloatTensor of size 1x3]
Variable containing:
 0  0  0
[torch.FloatTensor of size 1x3]
As you can see, grad_input holds the gradients with respect to the inputs of this module's Function, and grad_output holds the gradients with respect to the value returned by this module's forward. We cannot modify grad_input in place, but we can return a new new_grad_in that will be used as the gradient of the Function's inputs.
The code above registers backward hooks on both a Variable and a module. The key point is that, whether it is a module hook or a variable hook, it ultimately gets registered on a Function; you can see this by reading the source of Variable.register_hook and of Module's __call__.
Note that the behavior of Module.register_backward_hook may change in future releases.
During backprop, what happens inside a Function is probably something like this:
class Function:
    def __init__(self):
        ...

    def forward(self, inputs):
        ...
        return outputs

    def backward(self, grad_outs):
        ...
        return grad_ins

    def _backward(self, grad_outs):
        hooked_grad_outs = grad_outs
        for hook in hook_in_outputs:       # hooks registered on the output Variables
            hooked_grad_outs = hook(hooked_grad_outs)
        grad_ins = self.backward(hooked_grad_outs)
        hooked_grad_ins = grad_ins
        for hook in hooks_in_module:       # hooks registered on the Module
            hooked_grad_ins = hook(hooked_grad_ins)
        return hooked_grad_ins
A guess at how PyTorch's run_backward() might be implemented:
def run_backward(variable, gradient):
    creator = variable.creator
    if creator is None:
        # a leaf Variable: run its hook on the incoming gradient and store the result
        variable.grad = variable.hook(gradient)
        return
    grad_ins = creator._backward(gradient)
    vars = creator.saved_variables
    for var, grad in zip(vars, grad_ins):
        run_backward(var, grad)
The gradients of intermediate Variables are stored in a GradBuffer during backprop (this can be seen in the C++ source) and are released once backprop finishes; they are only kept if retain_grads=True.
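This is easy to observe from Python as well: after backward(), an intermediate (non-leaf) Variable has no .grad, and the usual ways to get at its gradient are a hook or retain_grad() (the latter assuming a PyTorch version that exposes it). A small sketch:

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([1, 2, 3]), requires_grad=True)
y = x * 2            # intermediate, non-leaf Variable
z = y.sum()

grads = {}
def save_grad(g):
    grads['y'] = g   # stash the gradient of y; returning None leaves it unchanged
y.register_hook(save_grad)
# y.retain_grad()    # alternative: ask autograd to keep y.grad after backward

z.backward()
print(x.grad)        # leaf gradient is kept: 2, 2, 2
print(y.grad)        # None: the intermediate gradient was not retained
print(grads['y'])    # 1, 1, 1, captured by the hook before it was freed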
Source: Keith