由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文将介绍两种在小batchsize也可以发挥BN性能的方法。
本文首发自极市平台,作者 @皮特潘,转载需获授权。
前言
BN(Batch Normalization)几乎是目前神经网络的必选组件,但是使用BN有两个前提要求:
- batchsize不能太小;
- 每一个minibatch和整体数据集同分布。
不然的话,非但不能发挥BN的优势,甚至会适得其反。但是由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文介绍两篇在小batchsize也可以发挥BN性能的方法。解决思路为:既然batchsize太小的情况下,无法保证当前minibatch收集到的数据和整体数据同分布。那么能否多收集几个batch的数据进行统计呢?这两篇工作分别分别是:
- BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
- CBN:Cross-Iteration Batch Normalization
另外,本文也会给出代码解析,帮助大家理解。
batchsize过小的场景
通常情况下,大家对CNN任务的研究一般为公开的数据集指标负责。分类任务为ImageNet数据集负责,其尺度为224X224。检测任务为coco数据集负责,其尺度为640X640左右。分割任务一般为coco或PASCAL VOC数据集负责,后者的尺度大概在500X500左右。再加上例如resize的前处理操作,真正送入网络的图片的分辨率都不算太大。一般性能的GPU也很容易实现大的batchsize(例如大于32)的支持。
但是实际的项目中,经常遇到需要处理的图片尺度过大的场景,例如我们使用500w像素甚至2000w像素的工业相机进行数据采集,500w的相机采集的图片尺度就是2500X2000左右。而对于微小的缺陷检测、高精度的关键点检测或小物体的目标检测等任务,我们一般不太想粗暴降低输入图片的分辨率,这样违背了我们使用高分辨率相机的初衷,也可能导致丢失有用特征。在算力有限的情况下,我们的batchsize就无法设置太大,甚至只能为1或2。小的batchsize会带来很多训练上的问题,其中BN问题就是最突出的。虽然大batchsize训练是一个共识,但是现实中可能无法具有充足的资源,因此我们需要一些处理手段。
BN回顾
首先Batch Normalization 中的Normalization被称为标准化,通过将数据进行平和缩放拉到一个特定的分布。BN就是在batch维度上进行数据的标准化。BN的引入是用来解决 internal covariate shift 问题,即训练迭代中网络激活的分布的变化对网络训练带来的破坏。BN通过在每次训练迭代的时候,利用minibatch计算出的当前batch的均值和方差,进行标准化来缓解这个问题。虽然How Does Batch Normalization Help Optimization 这篇文章探究了BN其实和Internal Covariate Shift (ICS)问题关系不大,本文不深入讨论,这个会在以后的文章中细说。
一般来说,BN有两个优点:
- 降低对初始化、学习率等超参的敏感程度,因为每层的输入被BN拉成相对稳定的分布,也能加速收敛过程。
- 应对梯度饱和和梯度弥散,主要是对于使用sigmoid和tanh的激活函数的网络。
当然,BN的使用也有两个前提:
- minibatch和全部数据同分布。因为训练过程每个minibatch从整体数据中均匀采样,不同分布的话minibatch的均值和方差和训练样本整体的均值和方差是会存在较大差异的,在测试的时候会严重影响精度。
- batchsize不能太小,否则效果会较差,论文给的一般性下限是32。
再来回顾一下BN的具体做法:
- 训练的时候:使用当前batch统计的均值和方差对数据进行标准化,同时优化优化gamma和beta两个参数。另外利用指数滑动平均收集全局的均值和方差。
- 测试的时候:使用训练时收集全局均值和方差以及优化好的gamma和beta进行推理。
可以看出,要想BN真正work,就要保证训练时当前batch的均值和方差逼近全部数据的均值和方差。
BRN
论文题目:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
论文地址: https://arxiv.org/pdf/1702.03275.pdf
代码地址: https://github.com/ludvb/batchrenorm
核心解析:
本文的核心思想就是:训练过程中,由于batchsize较小,当前minibatch统计到的均值和方差与全部数据有差异,那么就对当前的均值和方差进行修正。修正的方法主要是利用到通过滑动平均收集到的全局均值和标准差。看公式:
x i − μ σ = x i − μ B σ B ⋅ r + d , where r = σ B σ , d = μ B − μ σ \frac{x_{i}-\mu}{\sigma}=\frac{x_{i}-\mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}} \cdot r+d, \quad \text { where } r=\frac{\sigma_{\mathcal{B}}}{\sigma}, \quad d=\frac{\mu_{\mathcal{B}}-\mu}{\sigma} σxi−μ=σBxi−μB⋅r+d, where r=σσB,d=σμB−μ
上面公式中,i表示网络的第i层。μ和σ表示网络推理时的均值和标准差,也就是训练过程通过滑动平均收集的到均值和方差。μB和σb表示当前训练迭代过程中的实际统计到的均值和标准差。BN在小batch不work的根本原因就是这两组参数存在较大的差异。通过r和d对训练过程中数据进行线性变换,在该变化下,上公式左右两端就严格相等了。其实标准的BN就是r=1,d=0的一种情况。对于某一个特定的minibatch,其中r和d可以看成是固定的,是直接计算出来的,不需要梯度优化的。
具体流程:
-
统计当前batch数据的均值和标注差,和标准BN做法一致。
-
根据当前batch的均值和标准差结合全局的均值和标准差利用上面的公式计算r和d;注意该运算是不参与梯度反向传播的。另外,r和d需要增加一个限制,直接clip操作就好。
-
利用当前的均值和标准差对当前数据执行Normalization操作,利用上面计算得到的r和d对当前batch进行线性变换。
-
滑动平均收集全局均值和标注差。
测试过程和标准BN一样。其实本质上,就是训练的过程中使用全局的信息进行更新当前batch的数据。间接利用了全局的信息,而非当前这一个batch的信息。
实验效果:
在较大的batchsize(32)的时候,与标准BN相比,不会丢失效果,训练过程一如既往稳定高效。如下:
在小的batchsize(4)下, 本文做法依然接近batchsize为32的时候,可见在小batchsize下是work的。
代码解析:
def forward(self, x):
if x.dim() > 2:
x = x.transpose(1, -1)
if self.training: # 训练过程
dims = [i for i in range(x.dim() - 1)
batch_mean = x.mean(dims) # 计算均值
batch_std = x.std(dims, unbiased=False) + self.eps # 计算标准差
# 按照公式计算r和d
r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)
d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean))
/ self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)
# 对当前数据进行标准化和线性变换
x = (x - batch_mean) / batch_std * r + d
# 滑动平均收集全局均值和标注差
self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
self.running_std += self.momentum * (batch_std.detach() - self.running_std)
self.num_batches_tracked += 1
else: # 测试过程
x = (x - self.running_mean) / self.running_std
return x
CBN
论文题目:Cross-Iteration Batch Normalization
论文地址:https://arxiv.org/abs/2002.05712
代码地址:https://github.com/Howal/Cross-iterationBatchNorm
本文认为BRN的问题在于它使用的全局均值和标准差不是当前网络权重下获取的,因此不是exactly正确的,所以batchsize再小一点,例如为1或2时就不太work了。本文使用泰勒多项式逼近原理来修正当前的均值和标准差,同样也是间接利用了全局的均值和方差信息。简述就是:当前batch的均值和方差来自之前的K次迭代均值和方差的平均,由于网络权重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估计前面的迭代在当前权重下的数值。
泰勒公式:
泰勒公式是一个用函数在某点的信息描述其附近取值的公式。如果函数满足一定的条件,泰勒公式可以用函数在某一点的各阶导数值做系数构建一个多项式来近似表达这个函数。教科书介绍如下:
核心解析:
本文做法,由于网络一般使用SGD更新权重,因此网络权重的变化是平滑的,所以适用泰勒公式。如下,t为训练过程中当前迭代时刻,t-τ为t时刻向前τ时刻。θ为网络权重,权重下标代表该权重的时刻。μ为当前minibatch均值,v为当强minibatch平方的均值,是为了计算标准差。因此直接套用泰勒公式得到:
μ t − τ ( θ t ) = μ t − τ ( θ t − τ ) + ∂ μ t − τ ( θ t − τ ) ∂ θ t − τ ( θ t − θ t − τ ) + O ( ∥ θ t − θ