tensorflow保存和读取模型(通过图.meta)

由于我每过一段时间,去写模型的时候,就会忘记怎么保存和读取模型。于是,我写下这篇博客以用于自己做笔记。如果对大家有所帮助,那就感谢大家赏脸。如果哪里不足,还请大家评论告知。以弥补我自己的不足。

保存模型

保存模型很简单
两行代码就可以解决问题

x=tf.placeholder(tf.float32,[None,28,28,1],name='x')
y=tf.placeholder(tf.int64,[None],name='y')

# print(train_y)
y_=resnet(x,16,[3,3,3,3],10)
loss=tf.losses.sparse_softmax_cross_entropy(y,y_)

with tf.name_scope('train_op'):
    train=tf.train.AdamOptimizer().minimize(loss)

real=tf.argmax(y_,1)
corrent=tf.equal(real,y)
acc=tf.reduce_mean(tf.cast(corrent,tf.float64))
init=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
	sess.run(init)
	saver.save(sess,'model.resnet.ckpt')

上面很多代码省略,主要观看两行代码

saver=tf.train.Saver()
saver.save(sess,'model.resnet.ckpt')

于是乎,模型就保存完毕,在model文件夹生成了四个文件
mnist模型图
上图中,.meta文件夹就是我们模型的图了!

读取模型

有了模型,我们现在就开始读取模型,以便于我们做预测或进行迁移学习

model_path='model/'
saver=tf.train.import_meta_graph(model_path+'model.resnet.ckpt.meta')
with tf.Session() as sess:
	saver.restore(sess,tf.train.latest_checkpoint(model_path))
	graph = tf.get_default_graph()
	x=graph.get_tensor_by_name('x:0')
	y=graph.get_tensor_by_name('y:0')
	y_=graph.get_tensor_by_name('y_/BiasAdd:0')
	Y=sess.run(y_,feed_dict={x:X})

这样读取模型的好处,可以不必重新建立图结构。
如果不知道有的tensor的name,可以用一个方法,在里面,你找到name就可以了

    for op in graph.get_operations():
        print(op)

里面的东西很多,你需要耐心去找。

TensorFlow中,要读取模型并进行预测,通常需要遵循以下步骤: 1. **选择模型格式**: - **Checkpoint (ckpt)**: 这是最常见的模型保存格式,包含了变量及其值。如果使用的是checkpoint模型,你需要先加载变量: ```python with tf.Graph().as_default(): sess = tf.Session() saver = tf.train.import_meta_graph('model.meta') # 使用.meta文件导入元 saver.restore(sess, 'model.ckpt') # 加载具体的ckpt文件 # 获取输入输出节点,如"input_placeholder:0""output_node:0" input_node = sess.graph.get_tensor_by_name('input_placeholder:0') output_node = sess.graph.get_tensor_by_name('output_node:0') # 进行预测 prediction = sess.run(output_node, {input_node: input_data}) ``` - **Frozen Graph (.pb)**: 是经过优化冻结操作(`tf.train freezing`)后的二进制格式,可以节省存储空间并提高加载速度。使用frozen_graph模型,直接加载预训练的pb文件即可: ```python with tf.gfile.GFile('model.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name='') # 获取输入输出节点 input_node = sess.graph.get_tensor_by_name('input_node:0') output_node = sess.graph.get_tensor_by_name('output_node:0') prediction = sess.run(output_node, {input_node: input_data}) ``` 2. **使用上下文管理器Graph().as_default()**: 这有助于确保在模型预测时不会遇到空错误。 3. **设置Session**: 创建一个会话来运行计算。 4. **获取输入输出节点**: 根据模型结构找到对应的占位符输出节点。 5. **执行预测**: 提供输入数据,并通过会话运行输出节点以得到预测结果。 请注意,具体实现可能因模型结构的不同而有所差异。如果你遇到了"The Session graph is empty"的错误,确认已经正确加载了模型并且中包含了必要的操作。此外,博客[^1]提供了更详细的教程示例,可供参考。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值