机器学习系列3 Batch Normalization总结

   前段时间面试腾讯优图计算机视觉实习,被问到是否理解Batch Normalization,虽然自己能明白它是做了什么有什么作用,但是总觉得没有理解透彻,于是翻出论文阅读了一番并编程实现,下面算是记录一下吧。论文地址:https://arxiv.org/abs/1502.03167

一、BN解决了什么样的问题

  从论文题目我们就可以看出BN是解决“Internal Covariate Shift”问题的,那么什么是“Internal Covariate Shift”?

  随着网络层数的增加和数据量的增多,训练一个深度神经网络变得越来越难,收敛也越来越慢,于是大神们就相处了一种目前最为常用的损失函数的优化方法即随机梯度下降(SGD),把数据分为m个小批量样本,在每次迭代中选择一个小批量来训练,这样就加速了深度神经网络的训练,但同时也带来了模型超参数(如学习率)选择困难的问题,因为每一个层的输入都收它之前隐层参数的影响,因此当网络变得越来越深,超参的微小变化都能带来很大的影响。同时也由于每一层的数据输入分布受前层参数的影响从而导致该层输入的许多维度上的数据分布都被移动到了非线性激活函数的饱和区域从而导致梯度消失的问题。而BN正是消除了每层输入分布的变化,从而放宽SGD对于超参的选择,允许使用更大的学习率从而加快训练速度。与此同时通过将数据分布集中在非线性激活函数的非饱和区域从而避免了梯度消失的问题。

二、BN的基本思想

  BN的思想总的说来就是把神经网络中每一层的输入分布归一化,使每层输入分布都呈均值为0标准差为1的标准正态分布,从而加快训练速度防止梯度消失。让我们来看看标准正态分布长啥样

 

                                                          图1 标准正态分布

  95%的概率x其值落在了[-2,2]的范围内,让我们再看sigmoid激活函数。

 

  不难理解了,经过BN,激活函数的输入呈标准正态分布,也就是将数据向非线性激活函数的线性区域变动,增大导数值、增强反向传播的流动性,这样也就加快了训练速度,防止了梯度消失。另外也可以从归一化的作用上理解。

     

  归一化后分类器对于权重的微小变动不想归一化前那么敏感从而更易于优化。我想不少人都会有这样的疑问,既然将数据分布都移动到了激活函数的线性区域那么这个网络的表达能力就下降了啊,那采用激活函数还有什么意义?其实BN操作在将每层输入数据归一化后还有反归一化的操作,但是将反归一化的参数设为可学习的参数即缩放量与偏移量,这样就使得激活函数的饱和程度的以控制找到一个线性和非线性之间一个较好的平衡,这样既能享受非线性的较强表达能力又能避免梯度消失加快训练速度,这也是我认为很精妙的一个点。

三、如何实现BN

1.训练阶段BN的实现

  BN层是加在全连接层或卷积层后,激活函数之前,若对每层的输入做完全白化的话非常昂贵而且也不是处处可微,因此论文提出两个必要的简化。一是对输入的每个维度都单独做归一化,例如若t层的输入有d个维度,x = (x^{1},x^{2}...,x^{d}),那么我们对x进行归一化处理后的结果应为:

                                          \hat{x}^{(k)} = \frac{\mathrm{x^{(k)} - E[x^{(k)}]} }{\mathrm{\sqrt{Var[x^{(k)}]}}}

  第二个简化是之前提到过BN 是为了解决SGD中超参选择困难的问题提出的,那么我们用于归一化的平均值和标准差是使用当前minibatch的数据得到的。这样便得到了BN的算法如下:

 

                         

  这里值得注意的是最后一行就是之前提及到的反归一化过程这里的两个参数是要学习得到的。我们画出这个算法的计算图从而推导一下反向传播过程中要用到的各个可学习参数的导数表达式。推导过程如下(就在ipad上写一下了)

       推导之后与论文中所给的式子相同

2.测试阶段BN的实现

  在测试时我们当然不能还用minibatch来计算平均、标准差,论文中提出的方法是直接从训练数据取全局统计量来实施BN。但是在看了cs231n作业中对于测试时BN的实施是这样讲的:“During training we also keep an exponentially decaying running mean of the mean of the mean and variance of each feature, and these averages are used to normalize data at test time” 也就是说采用一种类似momentum SGD的方法:

runningmean = momentum * runningmean + (1-momentum) * samplemean

runningvar = momentum * runningvar + (1 - momentum) * samplevar

   这样做更为合理

四、总结BN的作用

  1. 提高网络梯度流

  2. 可允许更大的学习率,可在更广范围的学习率和不同初始值下工作,因此使用了BN训练更加容易

  3. 可看作一种正则化,因为对数据每一维都做了归一化,因此相当于加入了一些抖动,减轻了Dropout的使用。

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值