1.Batch normalization的公式
其中:xix_{i}xi是输入,μB\mu _{B}μB是均值,σB2\sigma _{B}^{2}σB2是方差,γγγ是缩放系数(scale),βββ是偏移(offset)系数,ε\varepsilonε是方差偏移系数,BN(xi)BN(x_{i})BN(xi)是输出。
2. Batch normalization介绍
批标准化(batch normalization,BN),一般用在激活函数之前,使结果y=wx+by=wx+by=wx+b,各个维度参数均值为0,方差为1。通过规范化让激活函数的输入分布在线性区间,让每一层的输入有一个稳定的分布会有利于网络的训练。
优点:
- 加大探索步长,加快收敛速度。
- 更容易跳出局部极小。
- 破坏原来的数据分布,一定程度上防止过拟合。
- 解决收敛速度慢和梯度爆炸。
3. Batch normalization的tensorflow API
3.1
mean, variance = tf.nn.moments(x, axes, name=None, keep_dims=False)
计算统计矩,mean 是一阶矩即均值,variance 则是二阶中心矩即方差,axes=[0]表示按列计算;
3.2
tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
tf.nn.batch_norm_with_global_normalization(x, mean, variance, beta, gamma, variance_epsilon, scale_after_normalization, name=None);
tf.nn.moments 计算返回的 mean 和 variance 作为 tf.nn.batch_normalization 参数调用;