在TensorFlow中,如果我们要使用batch normalization层,可以使用的API有tf.layers.batch_normalization和tf.contrib.layers.batch_norm,如果我们直接使用这两个API构建我们的网络,往往会出现训练的时候网络的表现非常好,而当测试的时候我们将其中的参数is_training设置为False时,网络的表现非常的差。
这种时候是因为我们并没有在训练过程中对BN层的参数进行更新。需要在代码的对应位置加入如下两行代码
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)