Batch Normalization原理+代码实现

部署运行你感兴趣的模型镜像

Batch Normalization 原理

Batch Normalization(BN)的核心思想是通过对神经网络每一层的输入进行标准化处理(减均值、除标准差),解决深层网络训练时的内部协变量偏移(Internal Covariate Shift)问题。具体原理如下:

  • 标准化:对每一层的输入数据按特征维度进行标准化,使其均值为0、方差为1。公式为: [ \hat{x}^{(k)} = \frac{x^{(k)} - \mu^{(k)}}{\sqrt{(\sigma^{(k)})^2 + \epsilon}} ] 其中,( \mu^{(k)} ) 和 ( \sigma^{(k)} ) 为当前batch中第( k )个特征的均值和标准差,( \epsilon )为极小值防止除零。

  • 可学习参数:标准化后引入可学习的缩放参数( \gamma )和平移参数( \beta ),恢复数据的表达能力: [ y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)} ]

  • 推理阶段:训练时统计的均值和方差通过移动平均保存,推理时直接使用全局统计量而非batch统计量。

Batch Normalization 代码实现(PyTorch)

以下是一个完整的BN层实现示例,包含训练和推理模式的支持:

import torch
import torch.nn as nn

class BatchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        
        # 可学习参数
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # 推理时使用的全局统计量
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
        
    def forward(self, x):
        if self.training:
            # 训练模式:使用当前batch的统计量
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            
            # 更新全局统计量(移动平均)
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # 推理模式:使用预存的全局统计量
            mean = self.running_mean
            var = self.running_var
        
        # 标准化计算
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        y = self.gamma * x_hat + self.beta
        return y

使用示例

# 初始化BN层(假设输入特征维度为5)
bn = BatchNorm1d(5)

# 训练模式
bn.train()
x = torch.randn(32, 5)  # batch_size=32
y = bn(x)

# 推理模式
bn.eval()
y_inference = bn(x)

注意事项

  • PyTorch原生实现:实际使用时可直接使用torch.nn.BatchNorm1d等官方接口,上述代码仅为教学目的。
  • 与Dropout的交互:BN层通常放在激活函数之前,且在Dropout层之后使用。
  • batch_size影响:小batch可能导致统计量估计不准,可尝试GroupNormLayerNorm替代。

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

shayudiandian

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值