AI梯度传播之链式法则

在神经网络的反向传播过程中,梯度是通过链式法则(Chain Rule)逐层相乘的方式从后向前传递的。这种机制使得前层的参数能够根据最终损失函数的反馈进行调整。下面通过具体示例和数学推导详细说明这一过程。


1. 链式法则的核心思想

梯度反向传播的本质是复合函数的求导。若前向传播为:

        y=f(g(h(x)))

则反向传播时,损失 LL 对输入 xx 的梯度为:

        ∂L/∂x=∂L/∂y*∂y/∂g*∂g/∂h*∂h/∂x

梯度是逐层相乘传递的,每一层的梯度依赖于后一层的梯度计算结果。


2. 以ResNet残差块为例的梯度传播

假设一个简化残差块结构:

        y=F(x)+x,其中 F(x)=W2⋅ReLU(W1⋅x)(忽略BatchNorm和偏置项以便简化)

(1) 前向传播步骤
  1. h1=W1⋅xh

  2. a1=ReLU(h1)

  3. h2=W2⋅a1

  4. y=h2+x

(2) 反向传播过程

假设损失函数 LL 对输出 yy 的梯度为 ∂L/∂y(来自后续层),则:

  1. 梯度传播至 h2和 x(残差连接的分支):

            ∂L/∂h2=∂L/∂y⋅1,        ∂L/∂x=∂L/∂y⋅1

    注:残差连接将梯度直接传递到前层,缓解梯度消失。

  2. 梯度传播至 W2W

    ∂L/∂W2=∂L/∂h2⋅a1T

    (需转置输入激活值 a1a1​)

  3. 梯度传播至 a1​

    ∂L/∂a1=W2T⋅∂L/∂h2
  4. 梯度传播至 h1(经过ReLU):

    ∂L/∂h1=∂L/∂a1⋅II (h1>0)

    (II 为指示函数,ReLU的梯度在输入≤0时为0)

  5. 梯度传播至 W1

    ∂L/∂W1=∂L/∂h1⋅xT
  6. 梯度传播至输入 x(主路径+残差路径):

    ∂L/∂x=W1T⋅∂L/∂h1+∂L/∂y

    这是残差结构的核心:梯度来自两条路径的叠加。


3. 梯度相乘传递的关键点

  1. 链式法则的逐层连乘
    每一层的梯度计算都依赖于后一层的梯度结果,例如:

    ∂L/∂W1=∂L/∂y(后层梯度) * ∂y/∂h2 * ∂h2/∂a1 * ∂a1/∂h1 *∂h1/∂W1(本层局部梯度)
  2. 梯度消失/爆炸的原因

    • 若多层梯度绝对值 ≪1≪1,连乘后前层梯度趋近于0(消失)。

    • 若多层梯度绝对值 ≫1≫1,连乘后前层梯度急剧增大(爆炸)。

    • ResNet的解决方案:残差连接提供恒等映射路径,确保至少有一条路径的梯度为1(∂y/∂x=1),缓解梯度消失。

  3. 局部梯度的依赖关系

    • 卷积层的梯度依赖于输入数据和权重矩阵的转置。

    • 激活函数(如ReLU)的梯度是局部门控(0或1)。


4. 可视化示例

假设一个3层简单网络(无残差):

L→Linear3→ReLU→Linear2→ReLU→Linear1→xL→Linear3​→ReLU→Linear2​→ReLU→Linear1​→x

反向传播时:

∂L/∂W1=∂L/∂Linear3 * ∂Linear3/∂Linear2 * ∂Linear2/ ∂Linear1 * ∂Linear1/ ∂W1

每一层的梯度都是后一层梯度与本层局部梯度的乘积。


5. 总结

  • 梯度传递规则:从后向前逐层连乘,链式法则是核心数学工具。

  • ResNet的改进:残差连接通过梯度叠加(而非连乘)缓解深层网络的梯度消失问题。

  • 实际训练:需注意梯度裁剪(Gradient Clipping)防止爆炸,以及归一化(如BatchNorm)稳定梯度分布。

通过这种机制,无论网络多深,前层参数都能获得有效的梯度更新信号。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值