巨坑提醒:tf.keras与tensorflow混用,trainable=False根本不起作用。正文不用看了。
摘要:
在tensorflow中,training参数和trainable参数是两个不同的参数,具有不同的约束功能。
- training: True/False,告诉网络现在是在训练阶段还是在inference(测试) 阶段。
- trainable:True/False, 设置这一层网络的参数变量是可在训练过程中被更新,是可训练还是不可训练。
代码1:tf.keras.layers.BatchNormalization的使用示例,及training参数和trainable的影响分析
import tensorflow.compat.v1 as tf
input = tf.ones([1, 2, 2, 3])
output = tf.keras.layers.BatchNormalization(trainable=True)(input,training=True)
# 手动添加滑动平均节点
ops = tf.get_default_graph().get_operations()
bn_update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type == "AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn_update_ops)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
print(tf.global_variables())
print(tf.trainable_variables())
"""
当 trainable=False,training=False 时,返回:
[[]]
[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32>]
[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32>]
当 trainable=False,training=True 时,返回:
[[<tf.Operation 'batch_normalization/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>, <tf.Operation 'batch_normalization/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>]]
[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32>, <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32>]
[<tf.Variable 'batch

这篇博客探讨了在TensorFlow中,tf.keras.layers.BatchNormalization层的trainable和training参数的区别。训练参数training主要控制网络是否在训练阶段,而trainable参数决定层的权重是否在训练过程中更新。在TensorFlow2.0之前,即使trainable=False,训练阶段仍会使用minibatch的均值和方差进行标准化;但在2.0之后,trainable=False会使层进入推理模式,使用滑动均值和方差。手动添加滑动平均节点对于批标准化的正确运行至关重要。建议在模型构建后立即添加这些节点,以确保批标准化层在训练和测试中的行为符合预期。
最低0.47元/天 解锁文章
2万+





