1)Batch Normalization解决的问题
Batch Normalization(BN)主要用于解决Internal Covariate Shift。由于训练过程中,网络各层数据x分布会发生变化(偏移),这个偏移可能是受不同batch间(或者训练集和测试集)的数据本身分布不同,或者是在训练过程,由于梯度回传,导致不同batch间各层数据分布前后不一致。
这个现象会导致模型训练更为困难,而且由于某些层数据偏移如果过大,导致其经过某些激活层(sigmoid函数等)后其梯度会趋于0,从而造成梯度消失的问题。
早期解决这个问题是对输入进行归一化到的高斯分布(白化),在网络层较小时,通过此类对输入数据初始化操作确实能解决这类问题,但随着网络层数加大,中间层仍然会出现偏移过大的情况。
一种直观的解决思路是对每一层都进行归一化操作,但这样会破坏每层原本的表达,特别是对于某些激活函数(如sigmoid函数),归一化操作还会使得数据分布都在线性区域,而失去了非线性的表达。
2)Batch Normalization的计算方式
BN的思路是在归一化操作后,加一个还原操作,通过这种方式实现训练过程中减少数据分布变化带来的偏移。
上图显示了BN的计算过程,其中在训练过程中,表示batch内的样本标签,
表示batch内一条输入样本,