反向传播中 batch 内梯度更新的具体流程

🎯 任务场景

假设我们训练一个最简单的全连接神经网络来做分类,具体情况如下:

  • 输入维度 x ∈ R 2 (即每个样本是一个二维向量) x \in \mathbb{R}^2 (即每个样本是一个二维向量) xR2(即每个样本是一个二维向量)
  • 输出类别:二分类(用 Sigmoid 激活)
  • 损失函数:均方误差(MSE)
  • batch size:4(简化演示)

🏋️‍♂️ 1. 网络定义

模型的形式是:
y pred = σ ( W x + b ) y_{\text{pred}} = \sigma(Wx + b) ypred=σ(Wx+b)
其中:

  • x 是输入样本 x 是输入样本 x是输入样本
  • W 是 2 维权重矩阵 W 是 2 维权重矩阵 W2维权重矩阵
  • b 是偏置 b 是偏置 b是偏置
  • σ ( z ) = 1 1 + e − z 是 S i g m o i d 激活 \sigma(z) = \frac{1}{1 + e^{-z}} 是 Sigmoid 激活 σ(z)=1+ez1Sigmoid激活

📸 2. 准备一个 batch 的数据

假设 batch size = 4,有 4 个样本:

样本编号输入 $ x_1, x_2 $标签 $ y $
1(0.2, 0.4)1
2(0.1, 0.5)0
3(0.3, 0.7)1
4(0.6, 0.9)0

🚀 3. 前向传播(对 batch 里的每个样本)

参数初始化

W = [ 0.1 0.2 ] , b = 0.05 W = \begin{bmatrix} 0.1 & 0.2 \end{bmatrix}, \quad b = 0.05 W=[0.10.2],b=0.05

① 第一个样本:

输入 x 1 = ( 0.2 , 0.4 ) x_1 = (0.2, 0.4) x1=(0.2,0.4)
z 1 = W x 1 + b = ( 0.1 ) ( 0.2 ) + ( 0.2 ) ( 0.4 ) + 0.05 = 0.15 z_1 = W x_1 + b = (0.1)(0.2) + (0.2)(0.4) + 0.05 = 0.15 z1=Wx1+b=(0.1)(0.2)+(0.2)(0.4)+0.05=0.15
预测值:
y pred 1 = σ ( 0.15 ) = 1 1 + e − 0.15 ≈ 0.537 y_{\text{pred}_1} = \sigma(0.15) = \frac{1}{1 + e^{-0.15}} \approx 0.537 ypred1=σ(0.15)=1+e0.1510.537
损失:
L 1 = 1 2 ( 0.537 − 1 ) 2 ≈ 0.107 L_1 = \frac{1}{2} (0.537 - 1)^2 \approx 0.107 L1=21(0.5371)20.107


② 第二个样本:

输入 x 2 = ( 0.1 , 0.5 ) x_2 = (0.1, 0.5) x2=(0.1,0.5)
z 2 = 0.16 z_2 = 0.16 z2=0.16
预测值:
y pred 2 ≈ 0.540 y_{\text{pred}_2} \approx 0.540 ypred20.540
损失:
L 2 = 1 2 ( 0.540 − 0 ) 2 ≈ 0.146 L_2 = \frac{1}{2} (0.540 - 0)^2 \approx 0.146 L2=21(0.5400)20.146


③ 第三个样本:

输入 x 3 = ( 0.3 , 0.7 ) x_3 = (0.3, 0.7) x3=(0.3,0.7)
z 3 = 0.22 z_3 = 0.22 z3=0.22
预测值:
y pred 3 ≈ 0.555 y_{\text{pred}_3} \approx 0.555 ypred30.555
损失:
L 3 ≈ 0.099 L_3 \approx 0.099 L30.099


④ 第四个样本:

输入 x 4 = ( 0.6 , 0.9 ) x_4 = (0.6, 0.9) x4=(0.6,0.9)
z 4 = 0.29 z_4 = 0.29 z4=0.29
预测值:
y pred 4 ≈ 0.571 y_{\text{pred}_4} \approx 0.571 ypred40.571
损失:
L 4 ≈ 0.163 L_4 \approx 0.163 L40.163


🔄 4. 反向传播(链式法则)

对每个样本,计算损失对参数的梯度:

单个样本的损失对参数的梯度(以第一个样本为例):

损失对输出的梯度:
∂ L 1 ∂ y pred 1 = ( 0.537 − 1 ) = − 0.463 \frac{\partial L_1}{\partial y_{\text{pred}_1}} = (0.537 - 1) = -0.463 ypred1L1=(0.5371)=0.463

输出对激活前的线性变换的梯度:
∂ y pred 1 ∂ z 1 = σ ( 0.15 ) ⋅ ( 1 − σ ( 0.15 ) ) ≈ 0.248 \frac{\partial y_{\text{pred}_1}}{\partial z_1} = \sigma(0.15) \cdot (1 - \sigma(0.15)) \approx 0.248 z1ypred1=σ(0.15)(1σ(0.15))0.248

对权重的梯度(链式法则):
∂ L 1 ∂ W = − 0.463 ⋅ 0.248 ⋅ ( 0.2 , 0.4 ) ≈ ( − 0.023 , − 0.046 ) \frac{\partial L_1}{\partial W} = -0.463 \cdot 0.248 \cdot (0.2, 0.4) \approx (-0.023, -0.046) WL1=0.4630.248(0.2,0.4)(0.023,0.046)


对整个 batch(4 个样本)求平均:

将每个样本的梯度求平均(分别对每个权重):
g W = 1 4 ∑ i = 1 4 ∂ L i ∂ W g_W = \frac{1}{4} \sum_{i=1}^4 \frac{\partial L_i}{\partial W} gW=41i=14WLi

假设最终得到的平均梯度为:
g W = ( − 0.015 , − 0.028 ) g_W = (-0.015, -0.028) gW=(0.015,0.028)


🔧 5. 更新参数

使用 SGD 更新参数:
W = W − η ⋅ g W W = W - \eta \cdot g_W W=WηgW
假设学习率 η = 0.1 \eta = 0.1 η=0.1更新为:
W 1 = 0.1 − 0.1 ⋅ ( − 0.015 ) = 0.1015 W_1 = 0.1 - 0.1 \cdot (-0.015) = 0.1015 W1=0.10.1(0.015)=0.1015
W 2 = 0.2 − 0.1 ⋅ ( − 0.028 ) = 0.2028 W_2 = 0.2 - 0.1 \cdot (-0.028) = 0.2028 W2=0.20.1(0.028)=0.2028


🎉 完整训练流程总结

  1. 前向传播:
    • 每个样本单独通过网络,生成预测值。
    • 计算损失。
  2. 反向传播:
    • 对每个样本独立计算梯度。
    • 对 batch 里的所有样本的梯度取平均。
  3. 更新参数:
    • 用平均梯度更新模型参数。

💡 直观理解

  • 单个样本的梯度会有噪音(因为单个样本可能是异常值)。
  • 通过 batch 取平均,相当于在多个样本之间“求共识”,让参数更新方向更稳健。
  • 如果 batch size 很小,更新会有噪音;如果 batch size 很大,更新会更稳定但可能收敛变慢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值