PyTorch自动微分机制解析:从计算图到梯度反向传播
计算图:动态计算的核心
PyTorch自动微分的核心是动态计算图,它是一种在运行时构建的有向无环图,用于记录张量间的操作历史。每当对张量执行操作时,如加法或矩阵乘法,PyTorch会将该操作作为一个节点加入到计算图中。计算图的叶子节点是输入的张量,根节点是输出的张量。这种动态特性使得图的结构可以根据每次前向传播的具体路径灵活变化,为复杂的控制流(如循环和条件语句)提供了自然的支持。通过计算图,系统能够精确追踪每一个数据是如何通过一系列操作变换而来的。
张量与梯度函数
在PyTorch中,`torch.Tensor`对象是构建计算图的基本单元。当张量的`requires_grad`属性设置为`True`时,PyTorch开始追踪在其上执行的所有操作。每个张量都有一个`.grad_fn`属性,该属性指向创建此张量的`Function`节点。这个`Function`节点是计算图中的核心组件,它不仅记录了生成该张量的操作,还包含了执行反向传播所需的逻辑。例如,一个由加法操作产生的张量,其`.grad_fn`将指向一个`AddBackward`对象,这个对象知道如何计算加法操作对输入张量的导数。
反向传播与梯度计算
当在前向传播中计算出损失(一个标量张量)后,调用损失张量的`.backward()`方法即可触发反向传播过程。此方法会从该损失张量开始,沿着计算图反向遍历每个节点。对于每个`Function`节点,它会调用其类中定义的`.backward()`方法,该方法根据链式法则,利用上游传递来的梯度(即损失对该节点输出的导数),计算该操作对其输入张量的局部梯度(即输出对输入的导数),并将这些局部梯度与上游梯度相乘,得到损失对该节点输入的梯度。这些梯度会被累加到相应输入张量的`.grad`属性中。
梯度累积与`zero_grad`的重要性
默认情况下,PyTorch在调用`.backward()`时会将计算出的梯度累加到张量的`.grad`属性上,而不是覆盖它。这种设计在某些高级场景(如RNN的BPTT)中是有用的,但对于标准的梯度下降优化,这会导致梯度在多次迭代间不断累积,从而产生错误的参数更新。因此,在每一次参数更新步骤之前,必须调用优化器的`.zero_grad()`方法(或手动将模型参数的`.grad`属性置零),以确保每次计算的都是当前批次的正确梯度,避免历史梯度的干扰。
`detach`与`no_grad`:控制梯度追踪
在某些情况下,我们可能希望从计算图中分离部分计算,以避免不必要的梯度计算或内存消耗。`tensor.detach()`方法会返回一个新的张量,它与原张量共享数据存储,但将其从创建它的计算历史中分离,新的张量`.requires_grad`为`False`且`.grad_fn`为`None`,从而阻止梯度向其继续反向传播。而`torch.no_grad()`上下文管理器则提供了一个更便捷的方式,在该上下文中进行的所有操作都不会被记录到计算图中。这在模型评估或仅需前向计算时非常有用,可以显著减少内存消耗并加速计算。
自定义自动微分函数
PyTorch允许用户通过继承`torch.autograd.Function`类来定义自定义的反向传播规则。需要重写两个静态方法:`forward(ctx, inputs)`执行前向计算,并可使用`ctx.save_for_backward`保存反向传播所需的数据;`backward(ctx, grad_outputs)`则接收上游梯度,根据前向传播保存的中间结果,计算并返回对每个输入的局部梯度。这为实现非标准或更高效的操作提供了灵活性,是扩展PyTorch功能的重要手段。
542

被折叠的 条评论
为什么被折叠?



