Backward()函数
Backward函数实际上是通过传递参数(默认情况下是1x1单位张量)来计算梯度的,它通过Backward图一直到每个叶节点,每个叶节点都可以从调用的根张量追溯到叶节点。然后将计算出的梯度存储在每个叶节点的.grad
中。请记住,在正向传递过程中已经动态生成了后向图。backward函数仅使用已生成的图形计算梯度,并将其存储在叶节点中。
让我们分析以下代码:
import torch
# Creating the graph
x = torch.tensor(1.0, requires_grad = True)
z = x ** 3
z.backward() #Computes the gradient
print(x.grad.data) #Prints '3' which is dz/dx
需要注意的一件重要事情是,当调用z.backward()
时,一个张量会自动传递为z.backward(torch.tensor(1.0))
。torch.tensor(1.0)
是用来终止链式法则梯度乘法的外部梯度。这个外部梯度作为输入传递给MulBackward
函数,以进一步计算x的梯度。传递到.backward()
中的