tensorflow版本1.4
tensorflow目前还没实现完全封装好的Batch Normalization的实现,这里主要试着实现一下。
关于理论可参见《 解读Batch Normalization》
对于TensorFlow下的BN的实现,首先我们列举一下需要注意的事项:
- (1)需要自动适应卷积层(batch_size*height*width*channel)和全连接层(batch_size*channel);
- (2)需要能够分别处理Training和Testing的情况,Training时需要更新均值和方差,Testing时使用历史滑动平均得到的均值与方差,即需要提供is_training的标志位参数;
- (3)最好提供滑动平均系数可调;
- (4)BN的计算量较大,尽量提高存储与运算效率;
- (5)需要注意alpha和beta参数可以被BP更新,而均值和方差通过计算得到;
- (6)load模型时,历史均值、方差以及alpha和beta参数需要被正常加载;
最终的实现如下:
#coding=utf-8
# util.py 用于实现一些功能函数
import tensorflow as tf
# 实现Batch Normalization
def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
# 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
shape = x.shape
assert