前言
pytorch中,变量分为叶子节点和非叶子节点,叶子节点没有grad_fn属性,而非叶子节点是由计算产生的,包含有grad_fn属性,对于叶子节点一般是不允许进行in-place运算的,in-place运算会导致叶子节点变为非叶子节点,这时一般会报错,关于叶子节点的in-place运算问题暂且先不讨论。主要讨论对于非叶子节点进行in-place运算与报错之间的联系。
实例
在pytorch中采取反向传播算法来计算梯度,并且应用求导的链式法则,在进行计算的时候,隐性的构建一个计算图,关于非叶子节点进行in-place运算的时候,会将变量的_version属性加1,当在反向传播的时候遇到同一个变量但是_version属性不一致的时候就会报错,如果在求梯度的时候只用到了一次或者没用过,那么即使进行了in-place运算也不会报错,依然可以求出梯度。
举例如下:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
if __name__ == '__main__':
Q = torch.randn(2, 2, 3, requires_grad=True)
V = torch.randn(2, 3, 3, requires_grad=True)
V1 = V + 1
a = torch.bmm(Q, V1.permute(0, 2, 1))
a /= math.sqrt(3)
V1[:, 0, 1] = 0
a[:, 1, 2:] = -1e6
a = F.softmax(a, dim=-1)
res = torch.bmm(a, V1)
l = res.sum().backward()
print(V.grad)
上述代码,在backward()的时候,会因为进行了in-place运算产生报错,原因在于,对torch.bmm(a,V1)操作求梯度的时候,对a的梯度是含有V1变量的,在torch.bmm(Q,V1)求对Q的梯度的时候,也是含有V1变量的,这两个变量的版本不一致,在这两个表达式之间进行了in-place运算,所以导致了报错,但是如果不对Q求梯度,只对V求梯度,那么就不会出现第二次V1变量,也就不会产生in-place运算的错误。
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
if __name__ == '__main__':
Q = torch.randn(2, 2, 3, requires_grad=False)
V = torch.randn(2, 3, 3, requires_grad=True)
V1 = V + 1
a = torch.bmm(Q, V1.permute(0, 2, 1))
a /= math.sqrt(3)
V1[:, 0, 1] = 0
a[:, 1, 2:] = -1e6
a = F.softmax(a, dim=-1)
res = torch.bmm(a, V1)
l = res.sum().backward()
print(V.grad)
把Q的requires_grad设为False,就可以成功的求出梯度了。