在神经网络的反向传播过程中,梯度是通过链式法则(Chain Rule)逐层相乘的方式从后向前传递的。这种机制使得前层的参数能够根据最终损失函数的反馈进行调整。下面通过具体示例和数学推导详细说明这一过程。
1. 链式法则的核心思想
梯度反向传播的本质是复合函数的求导。若前向传播为:
y=f(g(h(x)))
则反向传播时,损失 LL 对输入 xx 的梯度为:
∂L/∂x=∂L/∂y*∂y/∂g*∂g/∂h*∂h/∂x
梯度是逐层相乘传递的,每一层的梯度依赖于后一层的梯度计算结果。
2. 以ResNet残差块为例的梯度传播
假设一个简化残差块结构:
y=F(x)+x,其中 F(x)=W2⋅ReLU(W1⋅x)(忽略BatchNorm和偏置项以便简化)
(1) 前向传播步骤
-
h1=W1⋅xh
-
a1=ReLU(h1)
-
h2=W2⋅a1
-
y=h2+x
(2) 反向传播过程
假设损失函数 LL 对输出 yy 的梯度为 ∂L/∂y(来自后续层),则:
-
梯度传播至 h2和 x(残差连接的分支):
∂L/∂h2=∂L/∂y⋅1, ∂L/∂x=∂L/∂y⋅1注:残差连接将梯度直接传递到前层,缓解梯度消失。
-
梯度传播至 W2W:
∂L/∂W2=∂L/∂h2⋅a1T(需转置输入激活值 a1a1)
-
梯度传播至 a1:
∂L/∂a1=W2T⋅∂L/∂h2 -
梯度传播至 h1(经过ReLU):
∂L/∂h1=∂L/∂a1⋅II (h1>0)(II 为指示函数,ReLU的梯度在输入≤0时为0)
-
梯度传播至 W1:
∂L/∂W1=∂L/∂h1⋅xT -
梯度传播至输入 x(主路径+残差路径):
∂L/∂x=W1T⋅∂L/∂h1+∂L/∂y这是残差结构的核心:梯度来自两条路径的叠加。
3. 梯度相乘传递的关键点
-
链式法则的逐层连乘:
∂L/∂W1=∂L/∂y(后层梯度) * ∂y/∂h2 * ∂h2/∂a1 * ∂a1/∂h1 *∂h1/∂W1(本层局部梯度)
每一层的梯度计算都依赖于后一层的梯度结果,例如: -
梯度消失/爆炸的原因:
-
若多层梯度绝对值 ≪1≪1,连乘后前层梯度趋近于0(消失)。
-
若多层梯度绝对值 ≫1≫1,连乘后前层梯度急剧增大(爆炸)。
-
ResNet的解决方案:残差连接提供恒等映射路径,确保至少有一条路径的梯度为1(∂y/∂x=1),缓解梯度消失。
-
-
局部梯度的依赖关系:
-
卷积层的梯度依赖于输入数据和权重矩阵的转置。
-
激活函数(如ReLU)的梯度是局部门控(0或1)。
-
4. 可视化示例
假设一个3层简单网络(无残差):
L→Linear3→ReLU→Linear2→ReLU→Linear1→xL→Linear3→ReLU→Linear2→ReLU→Linear1→x
反向传播时:
∂L/∂W1=∂L/∂Linear3 * ∂Linear3/∂Linear2 * ∂Linear2/ ∂Linear1 * ∂Linear1/ ∂W1
每一层的梯度都是后一层梯度与本层局部梯度的乘积。
5. 总结
-
梯度传递规则:从后向前逐层连乘,链式法则是核心数学工具。
-
ResNet的改进:残差连接通过梯度叠加(而非连乘)缓解深层网络的梯度消失问题。
-
实际训练:需注意梯度裁剪(Gradient Clipping)防止爆炸,以及归一化(如BatchNorm)稳定梯度分布。
通过这种机制,无论网络多深,前层参数都能获得有效的梯度更新信号。