BN踩坑记--谈一下Batch Normalization的优缺点和适用场景

BN踩坑:BatchNormalization的优缺点与适用场景解析
本文讨论了BatchNormalization(BN)的优缺点,包括解决内部协变量偏移和梯度饱和,但强调了在小批量、RNN和测试阶段的局限性。文章指出BN在MLP和CNN中表现良好,但在动态文本模型如RNN中的效果较差,推荐关注层归一化在NLP中的应用。

BN踩坑记–谈一下Batch Normalization的优缺点和适用场景

这个问题没有定论,很多人都在探索,所以只是聊一下我自己的理解,顺便为讲 layer-norm做个引子。

BN的理解重点在于它是针对整个Batch中的样本在同一维度特征在做处理。

在MLP中,比如我们有10行5列数据。5列代表特征,10行代表10个样本。是对第一个特征这一列(对应10个样本)做一次处理,第二个特征(同样是一列)做一次处理,依次类推。

在CNN中扩展,我们的数据是N·C·H·W。其中N为样本数量也就是batch_size,C为通道数,H为高,W为宽,BN保留C通道数,在N,H,W上做操作。比如说把第一个样本的第一个通道的数据,第二个样本第一个通道的数据…第N个样本第一个通道的数据作为原始数据,处理得到相应的均值和方差。

BN有两个优点。

第一个就是可以解决内部协变量偏移,简单来说训练过程中,各层分布不同,增大了学习难度,BN缓解了这个问题。当然后来也有论文证明BN有作用和这个没关系,而是可以使损失平面更加的平滑,从而加快的收敛速度。

第二个优点就是缓解了梯度饱和问题(如果使用sigmoid激活函数的话),加快收敛。

BN的缺点:

第一个,batch_size较小的时候,效果差。这一点很容易理解。BN的过程,使用 整个batch中样本的均值和方差来模拟全部数据的均值和方差,在batch_size 较小的时候,效果肯定不好。

第二个缺点就是 BN 在RNN中效果比较差。这一点和第一点原因很类似,不过我单挑出来说。

首先我们要意识到一点,就是RNN的输入是长度是动态的,就是说每个样本的长度是不一样的。

举个最简单的例子,比如 batch_size 为10,也就是我有10个样本,其中9个样本长度为5,第10个样本长度为20。

那么问题来了,前五个单词的均值和方差都可以在这个batch中求出来从而模型真实均值和方差。但是第6个单词到底20个单词怎么办?

只用这一个样本进行模型的话,不就是回到了第一点,batch太小,导致效果很差。

第三个缺点就是在测试阶段的问题,分三部分说。

首先测试的时候,我们可以在队列里拉一个batch进去进行计算,但是也有情况是来一个必须尽快出来一个,也就是batch为1,这个时候均值和方差怎么办?

这个一般是在训练的时候就把均值和方差保存下来,测试的时候直接用就可以。那么选取效果好的均值和方差就是个问题。

其次在测试的时候,遇到一个样本长度为1000的样本,在训练的时候最大长度为600,那么后面400个单词的均值和方差在训练数据没碰到过,这个时候怎么办?

这个问题我们一般是在数据处理的时候就会做截断。

还有一个问题就是就是训练集和测试集的均值和方差相差比较大,那么训练集的均值和方差就不能很好的反应你测试数据特性,效果就回差。这个时候就和你的数据处理有关系了。

BN使用场景

对于使用场景来说,BN在MLP和CNN上使用的效果都比较好,在RNN这种动态文本模型上使用的比较差。至于为啥NLP领域BN效果会差,Layer norm 效果会好,下一个文章会详细聊聊我的理解。

列一下参考资料:

模型优化之Batch Normalization - 大师兄的文章 - 知乎 https://zhuanlan.zhihu.com/p/54171297

这个文章写的很好,推荐,从BN的特点(ICS/梯度饱和),训练,测试以及损失函数平滑都讲了一下。

李宏毅- Batch Normalization https://www.bilibili.com/video/av16540598/

大佬的讲解视频,不解释,推荐

各种Normalization - Mr.Y的文章 - 知乎 https://zhuanlan.zhihu.com/p/86765356

这个文章关于BN在CNN中使用的讲解很好,推荐一下。

### Batch Normalization的工作原理 Batch NormalizationBN)是一种在深度学习中广泛使用的正则化技术,其核心目标是通过标准化每一层的输入来加速训练过程并提高模型的稳定性。BN 的基本思想是对每个小批量(mini-batch)中的数据进行标准化处理,使得每一层的输入分布保持在均值为 0、方差为 1 的标准正态分布附近。这一过程有助于缓解梯度消失梯度爆炸问题,从而提升模型的训练效率泛化能力。 BN 的本质是通过强制拉回神经元输入分布到标准正态分布,使非线性激活函数的输入值落在对输入敏感的区域,从而避免梯度消失问题。具体而言,在训练过程中,BN 会计算当前 batch 数据的均值方差,并利用这两个统计量对输入进行标准化。随后,通过学习的缩放因子(scale)偏移因子(shift)对标准化后的数据进行线性变换,使得网络能够恢复原始的输入分布[^3]。 ### Batch Normalization的实现步骤 Batch Normalization 的实现过程可以分为以下几个步骤: 1. **计算均值方差** 对于当前 batch 的输入数据 $ x $,计算其均值 $ \mu_B $ 方差 $ \sigma_B^2 $,其中: $$ \mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i $$ $$ \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 $$ 其中 $ m $ 是 batch 的大小 [^1]。 2. **标准化输入数据** 利用上述计算得到的均值方差对输入数据进行标准化: $$ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} $$ 其中 $ \epsilon $ 是一个很小的常数,用于防止除以零的情况 [^1]。 3. **应用可学习参数进行线性变换** 引入两个可学习的参数 $ \gamma $ $ \beta $,对标准化后的数据进行缩放平移: $$ y_i = \gamma \cdot \hat{x}_i + \beta $$ 这一步允许网络根据需要调整标准化后的数据分布,从而恢复原始输入的表达能力 [^3]。 4. **训练过程中维护全局均值方差** 在训练阶段,BN 层会维护一个滑动平均的均值方差,用于推理阶段的标准化。这些全局统计量通常采用指数加权平均的方式更新: $$ \text{running\_mean} = \text{momentum} \cdot \text{running\_mean} + (1 - \text{momentum}) \cdot \mu_B $$ $$ \text{running\_var} = \text{momentum} \cdot \text{running\_var} + (1 - \text{momentum}) \cdot \sigma_B^2 $$ 其中 `momentum` 是一个超参数,通常设为 0.9 。 ### Batch Normalization的数学推导 BN 的数学推导主要围绕如何减少内部协方差偏移(Internal Covariate Shift)展开。内部协方差偏移是指在神经网络训练过程中,由于前面层的参数更新导致后面层输入分布发生变化的现象。BN 通过标准化输入数据,使得每一层的输入分布趋于稳定,从而加快训练速度并提升模型的收敛性 [^2]。 在反向传播过程中,BN 的梯度计算涉及均值方差的导数,这使得梯度的传播更加稳定。具体而言,BN 通过标准化操作将输入拉回到对激活函数敏感的区域,从而避免梯度消失或爆炸的问题 [^4]。 ### Batch Normalization的PyTorch实现 以下是一个简单的 Batch Normalization 实现示例,展示了其在 PyTorch 中的实现方式: ```python import torch import torch.nn as nn class BatchNorm(nn.Module): def __init__(self, num_features, eps=1e-5, momentum=0.9): super(BatchNorm, self).__init__() self.eps = eps self.momentum = momentum self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.gamma = nn.Parameter(torch.ones(num_features)) self.beta = nn.Parameter(torch.zeros(num_features)) def forward(self, x): if self.training: batch_mean = x.mean(dim=0) batch_var = x.var(dim=0, unbiased=False) # 更新 running mean running var self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var # 标准化 x_norm = (x - batch_mean) / torch.sqrt(batch_var + self.eps) else: x_norm = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps) return self.gamma * x_norm + self.beta ``` 在上述代码中,`running_mean` `running_var` 是用于推理阶段的全局统计量,而 `gamma` `beta` 是可学习的参数,用于恢复输入的原始分布。 ### 相关问题
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

onnx

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值