TensorFlow中关于saver读取中MovingAverage的一些注意事项

本文介绍TensorFlow中滑动平均模型的保存与恢复方法,包括如何利用ExponentialMovingAverage进行变量更新,如何通过variables_to_restore简化模型加载过程。
部署运行你感兴趣的模型镜像

saver中保存滑动平均模型中,当我们直接定义一个滑动平均类的操作后,会自动生成变量列表中所对应的shaddow variables, 具体细节代码如下:

#part 1

v = tf.Variable(0, dtype=tf.float32, name="v")
#创建滑动平均的类,给定初始衰减率0.5
ema = tf.train.ExponentialMovingAverage(0.5) 
#定义一个更新变量滑动平均操作。注意这里需要给定一个列表,每次执行这个操作的时候列表中的变量都会被更新 
ema_op = ema.apply(tf.global_variables())
#请注意,当我们定义好了变量的滑动平均操作之后,会自动生成一个"v/ExponentialMovingAverage"的影子变量
for variable in tf.global_variables():
    print(variable.name)
#输出:
#v:0
#v/ExponentialMovingAverage:0

#给变量赋值看v和v/ExponentialMovingAverage的取值并保存模型:
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 10))
    sess.run(ema_op)
    saver.save(sess, "./save_model/model.ckpt")
    print(sess.run([v, ema.avergae(v)]))       #其中ema.average(v)就代表了v的shaddow variable
#输出:
#[10.0, 5.0]

#part2 从saver中读取数据

v = tf.Variable(3.0, dtype=tf.float32, name="v")
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})      #把上面所保存的滑动平均值加载给v值  
with tf.Session() as sess:
    saver.restore(sess, "./save_model/model.ckpt") 
    print(sess.run(v))
#输出:
#5.0             注意,这个值是从文件中读取的v的滑动平均的值,如果想加载v的原值,直接tf.train.Saver({"v": v})或者tf.train.Saver()就行

#part3  当变量变得很多的时候,通过字典的方式来加载滑动平均值就显得不可能了,所以在TensorFlow中提供了variables_to_restore函数
#来生成tf.train.Saver()类所需的变量重命名字典, 样例如下:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.5)
#通过使用variables_to_restore函数来生成上面tf.train.Saver()所需的字典
print(ema.variables_to_restore())
#输出:
#{"v/ExponentialMovingAverage": <tensorflow.Variable "v:0" shape=() dtype=float32_ref>}

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "./save_model/model.ckpt")
    print(sess.run(v)          #输出 5.0,  也就是原来模型中变量v的滑动平均值
总结:当我们想运用以及训练好的权重的滑动平均值来预测数据的时候(滑动平均值可以让神经网络模型变得更加健壮)直接通过用variables_to_restore来生成我们所需的变量重命名字典,这样效率跟自己手动定义字典相比大幅提升,而且代码简洁。







您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值