Batch Normalization 原理
Batch Normalization(BN)的核心思想是通过对神经网络每一层的输入进行标准化处理(减均值、除标准差),解决深层网络训练时的内部协变量偏移(Internal Covariate Shift)问题。具体原理如下:
-
标准化:对每一层的输入数据按特征维度进行标准化,使其均值为0、方差为1。公式为: [ \hat{x}^{(k)} = \frac{x^{(k)} - \mu^{(k)}}{\sqrt{(\sigma^{(k)})^2 + \epsilon}} ] 其中,( \mu^{(k)} ) 和 ( \sigma^{(k)} ) 为当前batch中第( k )个特征的均值和标准差,( \epsilon )为极小值防止除零。
-
可学习参数:标准化后引入可学习的缩放参数( \gamma )和平移参数( \beta ),恢复数据的表达能力: [ y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)} ]
-
推理阶段:训练时统计的均值和方差通过移动平均保存,推理时直接使用全局统计量而非batch统计量。
Batch Normalization 代码实现(PyTorch)
以下是一个完整的BN层实现示例,包含训练和推理模式的支持:
import torch
import torch.nn as nn
class BatchNorm1d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.eps = eps
self.momentum = momentum
# 可学习参数
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# 推理时使用的全局统计量
self.running_mean = torch.zeros(num_features)
self.running_var = torch.ones(num_features)
def forward(self, x):
if self.training:
# 训练模式:使用当前batch的统计量
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
# 更新全局统计量(移动平均)
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
else:
# 推理模式:使用预存的全局统计量
mean = self.running_mean
var = self.running_var
# 标准化计算
x_hat = (x - mean) / torch.sqrt(var + self.eps)
y = self.gamma * x_hat + self.beta
return y
使用示例
# 初始化BN层(假设输入特征维度为5)
bn = BatchNorm1d(5)
# 训练模式
bn.train()
x = torch.randn(32, 5) # batch_size=32
y = bn(x)
# 推理模式
bn.eval()
y_inference = bn(x)
注意事项
- PyTorch原生实现:实际使用时可直接使用
torch.nn.BatchNorm1d等官方接口,上述代码仅为教学目的。 - 与Dropout的交互:BN层通常放在激活函数之前,且在Dropout层之后使用。
- batch_size影响:小batch可能导致统计量估计不准,可尝试
GroupNorm或LayerNorm替代。
1386

被折叠的 条评论
为什么被折叠?



