1. BatchNorm 的完整前向传播操作
训练模式(Training Mode)的完整流程:
完整一次BatchNorm前向传播的步骤分解:
步骤1:计算统计量
- 均值计算:对当前batch内所有样本按通道计算均值
(对于卷积层:在batch维度、高度维度、宽度维度上计算)
- 方差计算:对当前batch内所有样本按通道计算方差
步骤2:归一化处理
- 中心化:输入数据减去均值
- 标准化:中心化后的数据除以(标准差 + 极小常数ε)
(防止除以零,通常ε=1e-5)
步骤3:缩放和平移(可学习参数变换)
- 缩放:乘以可学习参数γ(gamma)
- 平移:加上可学习参数β(beta)
- 这一步让网络可以学习恢复或改变数据分布
步骤4:更新运行统计量(仅训练时)
- 采用移动平均方式更新全局统计量
- running_mean = 动量×原running_mean + (1-动量)×当前mean
- running_var = 动量×原running_var + (1-动量)×当前var
- 动量通常设为0.9
步骤5:保存缓存(用于反向传播)
- 保存:输入数据x、均值mean、方差var、归一化结果x_norm、ε值
- 这些缓存将在反向传播时使用
可视化操作流程:
┌─────────────────────────────────────────────┐
│ 输入 x (N, C, H, W) │
├─────────────────────────────────────────────┤
│ 步骤1:计算统计量 │
│ - mean = 对batch内所有样本计算均值 │
│ (按通道,跨N,H,W维度) │
│ - var = 对batch内所有样本计算方差 │
├─────────────────────────────────────────────┤
│ 步骤2:归一化 │
│ - x_centered = x - mean │
│ - x_norm = x_centered / sqrt(var + ε) │
├─────────────────────────────────────────────┤
│ 步骤3:缩放和平移(可学习参数) │
│ - y = γ * x_norm + β │
├─────────────────────────────────────────────┤
│ 步骤4:更新运行统计量(训练时) │
│ - running_mean = m*running_mean + (1-m)*mean│
│ - running_var = m*running_var + (1-m)*var │
├─────────────────────────────────────────────┤
│ 步骤5:保存缓存(用于反向传播) │
│ - cache = {x, mean, var, x_norm, γ, β, ε} │
└─────────────────────────────────────────────┘
2. 完整反向传播操作
反向传播时BatchNorm的计算步骤:
步骤1:计算γ和β的梯度
- dγ = Σ(上一层梯度 × 归一化后的值)
- dβ = Σ(上一层梯度)
步骤2:计算归一化结果的梯度
- dx_norm = 上一层梯度 × γ
步骤3:计算方差的梯度
- 因为方差出现在归一化的分母中
- 通过链式法则计算方差的梯度贡献
步骤4:计算均值的梯度
- 均值影响两项:直接相减项和方差计算项
- 需要综合这两部分的影响
步骤5:计算输入x的最终梯度
- x的影响通过三条路径传播:
1. 直接通过归一化项
2. 通过方差项
3. 通过均值项
- 将三条路径的梯度相加得到最终梯度
步骤6:传递梯度并更新参数
- 将梯度dx传递给前一层
- 用dγ和dβ更新γ和β参数
3. 推理模式(Inference Mode)的完整操作
推理阶段的简化流程:
步骤1:使用运行统计量进行归一化 - 不再计算当前batch的统计量 - 直接使用训练过程中积累的running_mean和running_var - 归一化公式:x_norm = (x - running_mean) / sqrt(running_var + ε) 步骤2:应用训练好的变换参数 - 使用训练好的γ和β参数(不再更新) - 输出:y = γ * x_norm + β 步骤3:直接输出结果 - 无需保存缓存 - 无需更新统计量 - 计算速度更快,结果确定
4. BatchNorm的核心组件
一个完整BatchNorm层包含的四个核心部分:
1. 统计计算模块 - 训练时:动态计算batch统计量 - 推理时:使用固定统计量 2. 归一化变换模块 - 标准化:使数据零均值、单位方差 - 可学习的逆变换:γ缩放 + β平移 3. 统计维护模块 - 维护running_mean和running_var - 采用移动平均策略更新 4. 模式切换机制 - 训练模式:完整流程,更新统计 - 推理模式:简化流程,使用固定统计
5. 一次完整BatchNorm操作的时间线
训练阶段一次完整迭代的时间流程:
前向传播阶段: 时间点t0-t1: 接收输入数据,计算当前batch的均值和方差 时间点t2: 执行归一化:减去均值,除以标准差 时间点t3: 执行可学习变换:γ缩放 + β平移 时间点t4: 输出变换后的数据给下一层 时间点t5: 保存所有中间结果到缓存 时间点t6: 更新全局统计量(running_mean, running_var) 反向传播阶段: 时间点t7: 接收来自上一层的梯度 时间点t8-t9: 计算γ和β的梯度 时间点t10: 计算输入x的梯度(复杂链式求导) 时间点t11: 将梯度传递给前一层 时间点t12: 使用梯度更新γ和β参数
推理阶段一次完整操作的时间流程:
时间点t0: 接收输入数据 时间点t1: 使用预训练的running_mean和running_var进行归一化 时间点t2: 使用预训练的γ和β进行缩放平移 时间点t3: 直接输出结果 时间点t4: 完成,无需任何额外操作
6. BatchNorm的数学本质
训练阶段数学表达:
μ_B = (1/m) Σ x_i # 计算batch均值 σ²_B = (1/m) Σ (x_i - μ_B)² # 计算batch方差 x̂_i = (x_i - μ_B) / √(σ²_B + ε) # 归一化 y_i = γ x̂_i + β # 缩放平移
推理阶段数学表达:
x̂_i = (x_i - μ_running) / √(σ²_running + ε) y_i = γ x̂_i + β
关键区别:
-
训练:μ_B和σ²_B来自当前batch,会更新running统计
-
推理:使用固定的μ_running和σ²_running,不更新
7. 实际应用中的注意事项
BatchNorm的正确使用方式:
1. 位置选择: - 放在卷积层或全连接层之后 - 放在激活函数之前 - 每个需要稳定训练的层后都可能需要 2. 参数初始化: - γ初始化为1(保持初始分布不变) - β初始化为0(初始不偏移) - running_mean初始化为0 - running_var初始化为1 3. 训练技巧: - batch size不能太小(影响统计量可靠性) - 推理前需要足够训练让统计量稳定 - 注意训练和推理的模式切换 4. 特殊场景处理: - 迁移学习时可能需要调整统计量 - 小批量训练时需谨慎使用 - 序列数据可能使用LayerNorm更合适
总结:完整BatchNorm的本质
BatchNorm不是单一操作,而是一个完整的数据处理流程:
输入 → [统计计算] → [归一化] → [可学习变换] → 输出
↓ ↑ ↓
[统计更新] [参数学习] [缓存保存]
↓ ↑ ↓
[全局统计维护] ←--- [反向传播] ←--- [梯度计算]
这个流程确保了:
-
训练稳定性:通过归一化防止内部协变量偏移
-
灵活性:通过γ和β让网络学习最佳数据分布
-
一致性:训练和推理使用相同的数据处理逻辑
-
效率:推理时简化计算,提高速度
这就是一次完整BatchNorm操作的全部内涵,它是一个精心设计的、包含多个协同工作的子模块的完整系统。
1018

被折叠的 条评论
为什么被折叠?



