批标准化(Batch Normalization, BN)
我们知道在深层神经网络训练时,会产生梯度消失问题,梯度消失的一个主要来源就是激活函数,以sigmoid函数为例,自变量在0附近导数较大,自变量越大或者越小都会造成导数缩小,如果数据分布在过小或者过大区域,过小的导数就很可能产生梯度消失问题(为何梯度消失要看激活函数的导数,这涉及到反向传播推导,可以看我的另一篇blog)。在比如,如果我们的数据大量分布在远离0的点,例如20和200,差了10倍,但是经过sigmoid激活后,输出值十分接近(我用torch.sigmoid()函数计算,结果全是1. ,可见差距是非常的小),这种过强的非线性,就很难体现出数据的差异性。为了解决这些问题,批标准化就被提出。
批标准化的基本思路是把数据拉扯到0附近,让数据满足均值为0,方差为1的正态分布,说到这里,一些人可能会想到某些方法,例如Z-Score标准化,其实批标准化就和Z-Score标准化很像,只不过我们不仅仅对输入做标准化,我们还在每一隐层的输入激活函数前,也加一层标准化。批标准化一般表示成 y = B N γ , β ( x ) y=BN_{\gamma, \beta}(x) y=BNγ,β(x)。 B N γ , β ( ⋅ ) BN_{\gamma, \beta}(·) BNγ,β(⋅)由下面四个操作构成:

上图来自论文Batch normalization: accelerating deep network training by reducing internal covariate shift,我先解释一下输入,假设目前我们想要进行批标准化的是隐层 l l l,该隐层输入的size是(B,N),其中B是batch size,N是当前隐层 l l l需要进行批标准化的神经元的个数,如果B=3,N=5,那么当前隐层输入的值,可以表示成下面的式子 [ [ x 11 x 12 x 13 x 14 x 15 ] [ x 21 x 22 x 23 x 24 x 25 ] [ x 31 x 32 x 33 x 34 x 35 ] ] \left[ \begin{matrix} [x_{11} & x_{12} & x_{13} & x_{14} & x_{15} & ] \\ [x_{21} & x_{22} & x_{23} & x_{24} & x_{25} & ] \\ [x_{31} & x_{32} & x_{33} & x_{34} & x_{35} & ] \end{matrix} \right] ⎣⎡[x11[x21[x31x12x22x32x13x23x33x14x24x34x15x25x35]]]⎦⎤
在Algoritm.1算法中,batch中的一个 x i x_{i} xi就是 [ x i 1 , x i 2 , x i 3 , x i 4 , x i 5 ] [x_{i1}, x_{i2}, x_{i3}, x_{i4}, x_{i5}] [xi1,xi2,xi3,xi4,xi5],得到的 μ B , σ B 2 , x ^ i , y i \mu_{B}, \sigma^{2}_{B},\hat{x}_{i}, y_{i} μB,σB2,x^i,yi全是五维的,所以上面的公式其实是向量运算。对于全连接层的批标准化,其实是考虑单个神经元的值为一个集体,这一集体有batch size个元素,在这一集体进行上面的四步操作,神经元间是独立的,同一隐层每个神经元各求得一个 μ B , σ B 2 \mu_{B}, \sigma^{2}_{B} μB,σB2。也就是说, x 11 , x 21 , x 31 x_{11}, x_{21}, x_{31} x11,x21,x31求一个均值,方差; x 12 , x 22 , x 33 x_{12}, x_{22}, x_{33} x12,x22,x33求一个均值,方差……
至于为何会存在 γ , β \gamma, \beta γ,β,是因为我们不想让数据经过批标准化后真的落在0附近,以sigmoid为例,0附近接近线性,而我们的神经网络需要依靠激活函数提供分线性,所以会加一个回退操作,这就是 γ , β \gamma, \beta γ,β的所用, γ , β \gamma, \beta γ,β可以通过反向传播算法进行学习, γ , β \gamma, \beta γ,β在上面例子中也是五维的。
推断过程
我们知道训练数据时一般会用mini-batch,但是训练后作用于验证集或测试集时,只会一次带入一条数据,没有均值和方差,我们如何对单条数据进行批标准化呢?针对这种情况,我们利用训练集中的均值和方差,下图是论文中的原图,1-6步是训练过程,7-12步是推断过程,

从图中我们可以看出来,推断过程中,我们利用训练集的所有batch的均值 μ B \mu_{B} μB,方差 σ B 2 \sigma_{B}^{2} σB2,来通过公式 E [ x ] ← E B [ μ B ] V a r [ x ] ← m m − 1 E B [ σ B 2