tensorflow保存模型和恢复模型

本文详细介绍使用TensorFlow保存和恢复模型的过程。首先通过构建计算图并初始化变量,然后使用Saver类保存模型到文件。恢复模型时,通过导入元图和最新检查点,重新加载模型状态,并进行预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

保存模型

w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}


w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#创建saver的实例
saver = tf.train.Saver()

#打印w4
print(sess.run(w4,feed_dict))
#w4=(w1+w2)*b1,值为24

#保存权重
saver.save(sess, 'my_test_model',global_step=1000)

恢复模型

import tensorflow as tf

sess=tf.Session()    
#加载graph
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))

#直接访问已保存的变量
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved

#准备网络的输入
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#访问想要运行的操作 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#打印出60
print(sess.run(op_to_restore,feed_dict))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

upDiff

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值