Batch Normalization
暑期实习面试真题!手写batch normalization实现。
torch中源码:https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
实现思想
Batch Normalization的主要思想是在训练过程中对每个小批量的输入进行标准化,以使网络更稳定、更快地收敛,并且有时还可以提高泛化能力。
1、在训练阶段,对于每个输入的小批量数据,计算该批量数据的均值和方差。这些统计信息用于标准化数据。
2、对于每个输入特征,将其标准化为零均值和单位方差。这是通过减去批量的均值,然后除以批量的标准差实现的。
3、训练过程中,每个Batch Normalization层会根据当前小批量数据的均值和方差来更新内部保存的滑动平均值。这些滑动平均值会在训练过程中累积,并在测试过程中使用。
4、测试过程中,使用保存的滑动平均值来进行归一化。
公式
x=x−xˉVar(x)+epsx =\frac{ x - \bar{x} }{\sqrt{Var(x)+eps} }x=Var(x)+epsx−xˉ