在tf中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,初始化时需要提供衰减率(decay),这个衰减率将用于控制模型的更新速度,ExponentialMovingAverage对每个变量会维护一个影子变量:shadow_variable = decay × shadow_variable + (1 - decay) × variable,shadow_variable为影子变量,variable为待更新变量,decay一般设置成十分接近1的数,ExponentialMovingAverage还提供了num_updates参数来动态设置decay的大小使得训练前期可以更新的更快而在后期更新速递递减,每次使用的衰减率将是:min{decay, (1+num_updates)/(10+num_updates)},下面代码是解释ExponentialMovingAverage如何被使用的。
首先是定义变量,tf.Varialbe():trainable: 如果为True(默认也为Ture),这个变量就会被添加到图的集合GraphKeys.TRAINABLE_VARIABLES.中去 ,这个collection被作为优化器类的默认列表,这里设置成了False,设定trainable=False 可以防止该变量被数据流图的GraphKeys.TRAINABLE_VARIABLES 收集, 这样我们就不会在训练的时候尝试更新它的值。
接着定义ExponentialMovingAverage,设定两个参数decay和num_updates。
ema.apply([v1])定义更新变量滑动平均的操作,参数是需要给定一个列表,每次执行这个操作时这个列表中的变量将会更新。
接着每次调用sess.run(maintain_averages_op)时v1的平均滑动会被更新,更新公式为decay × shadow_variable + (1 - decay) × variable,第一次运行时shadow_variable影子变量值为0,待更新参数v1为5,由于num_updates参数=step=0,decay = min{decay,(1+num_updates)/(10+num_updates)} = min{0.99, 1/10} = 0.1,故通过公式计算得到的滑动平均值为ema.average(v1)=4.5,所以在打印输出结果中print(sess.run([v1, ema.average(v1)]))将输出[5.0, 4.5],后面运行sess.run(maintain_averages_op)之前更新了variable和num_updates,但是shadow_variable都是承接着上一次更新公式的结果,比如第二次num_updates=step更新成了10000,待更新变量v1更新成了10,但shadow_variable还承接着上一次更新结果所以shadow_variable=4.5,再通过公式decay × shadow_variable + (1 - decay) × variable可以计算得到滑动平均值为ema.average(v1)=4.555,所以第二次print(sess.run([v1, ema.average(v1)]))得到结果为[10.0, 4.555],第三次类推。
import tensorflow as tf
v1 = tf.Variable(0, dtype=tf.float32)
step = tf.Variable(0, trainable=False)
ema = tf.train.ExponentialMovingAverage(0.99, num_updates=step)
maintain_averages_op = ema.apply([v1])
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
print(sess.run([v1, ema.average(v1)]))
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
得到结果
[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.555]
[10.0, 4.60945]