参考链接:
https://www.zhihu.com/question/38102762
https://zhuanlan.zhihu.com/p/26138673
https://blog.youkuaiyun.com/hjimce/article/details/50866313
https://blog.youkuaiyun.com/myarrow/article/details/51848285
###原理
BN的本质是解决了反向传播过程中的梯度消失问题。
####梯度消失/爆炸问题
- 前向传播:hl+1=Wlhlh_{l+1} = W_lh_l hl+1=Wlhl
- 反向传播:
梯度求解的一般形式∂L∂hi=∂L∂hl∂hl∂hl−1...∂hi+1∂hi\frac{\partial L}{\partial h_i} = \frac{\partial L}{\partial h_l}\frac{\partial h_l}{\partial h_{l-1}}...\frac{\partial h_{i+1}}{\partial h_i}∂hi∂L=∂hl∂L∂hl−1∂hl...∂hi∂hi+1
KaTeX parse error: No such environment: eqnarray at position 7: \begin{̲e̲q̲n̲a̲r̲r̲a̲y̲}̲ \frac{\parti…
上式的结果为权重的连乘,我们知道:0.930=0.040.9^{30} = 0.040.930=0.04 , 1.130=17.41.1^{30} = 17.41.130=17.4,这就解释了为什么会出现梯度消失和爆炸问题。
BN的实质是网络输出的变换。令x为某一卷积网络层的输出,则BN变换Y=BN(x)Y = BN(x)Y=BN(x)如下:
xˉ=1M∑i=iMxi\bar x = \frac{1}{M}\sum_{i=i}^Mx_ixˉ=M1i=i∑Mxi
σx=1M∑i=iM(xi−xˉ)2\sigma_x = \frac{1}{M}\sum_{i=i}^M(x_i-\bar x)^2σx=M1i=i∑M(xi−xˉ)2
x^=x−xˉσx+ϵ\hat x = \frac{x - \bar x}{\sqrt{\sigma_x + \epsilon}} x^=σx+ϵx−xˉ
Y=γx^+βY= \gamma \hat x + \betaY=γx^+β
那么BN是如何解决梯度消失和爆炸问题的呢?
主要思想:解决scale对梯度的影响,让BN变换至少具有能恢复原始数据的能力。
∂Yl+1∂hl=∂BN(hl+1)∂hl=∂BN(Wlhl)∂hl=∂BN(αWlhl)∂hl\frac{\partial Y_{l+1}}{\partial h_l} =\frac{\partial BN(h_{l+1})}{\partial h_l} = \frac{\partial BN(W_lh_l)}{\partial h_l} = \frac{\partial BN(\alpha W_lh_l)}{\partial h_l} ∂hl∂Yl+1=∂hl∂BN(hl+1)=∂hl∂BN(Wlhl)=∂hl∂BN(αWlhl)
不管参数变化多大,传回上一层的梯度∂Yl+1∂hl\frac{\partial Y_{l+1}}{\partial h_l}∂hl∂Yl+1始终不变,不受尺度scale的影响。
∂Yl+1∂Wl=∂BN(hl+1)∂Wl=∂BN(Wlhl)∂Wl=1α∂BN(αWlhl)∂Wl\frac{\partial Y_{l+1}}{\partial W_l} =\frac{\partial BN(h_{l+1})}{\partial W_l} = \frac{\partial BN(W_lh_l)}{\partial W_l} = \frac{1}{\alpha} \frac{\partial BN(\alpha W_lh_l)}{\partial W_l} ∂Wl∂Yl+1=∂Wl∂BN(hl+1)=∂Wl∂BN(Wlhl)=α1∂Wl∂BN(αWlhl)
对用于更新参数W的梯度∂Yl+1∂Wl\frac{\partial Y_{l+1}}{\partial W_l}∂Wl∂Yl+1,如果Wl′=αWlW_l' = \alpha W_lWl′=αWl, 则grad(W′)=1αgrad(W)grad(W') =\frac{1}{\alpha} grad(W)grad(W′)=α1grad(W)。如果α<1\alpha< 1α<1,则1α>1\frac{1}{\alpha} > 1α1>1 说明尺度较大的参数会获得比较小的梯度;相反,尺度较小的参数会获得比较大的梯度,使得整个网络的参数更新变得更加稳健(所以我们最后参数会趋向于同样大小?)
###面试常问问题
- BN怎么回事?什么原理?
- BN中有两个参数 γ\gammaγ 和 β\betaβ后的均值和方差在训练和预测的时候需要怎么处理?
BN中有两个参数 γ\gammaγ 和 β\betaβ,这个两个参数怎么回事,有什么需要注意的?
这两个参数是可学习的参数。(其实每个BN都包含两个这样的参数)
训练的时候记录每个mini-batch的均值 μ\muμ 和方差 σ2\sigma^2σ2,最后在测试的时候,用均值 μ\muμ 和方差 σ2\sigma^2σ2的无偏估计来计算。
(在pytorch中,一般用momentum来更新Inference时使用的均值 μ\muμ 和方差 σ2\sigma^2σ2。具体来说,xnew=x∗(1−momentum)+momentim∗xtx_{new} = x * (1 - momentum) + momentim * x_txnew=x∗(1−momentum)+momentim∗xt
参考:https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)。
- BN和Hisssian矩阵的关系