TensorFlow模型op的保存和加载(含演示代码)

本文详细讲解了如何在TensorFlow中保存和复用OP,包括为OP指定名称,训练时的保存方法,识别阶段的加载与运行,以及在实际应用中的完整演示代码。强调了在加载OP时,需要通过get_operation_by_name().outputs[0]获取OP,并通过get_tensor_by_name()获取feed_dict输入tensor的声明。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上一篇博文《TensorFlow模型参数的保存和加载》介绍了如何保存和加载TensorFlow模型训练参数,保存对象主要是Tensor/Variables。这一节我们介绍如何保存和复用op。

和Tensor一样,保存op需要在训练时为op指定名字,如下所示:

softmax = tf.nn.softmax(tf.matmul(x, W) + b,name="op_softmax")


在识别阶段,调用get_operation_by_name()函数,以op名字作为参数,如下所示:

op_softmax =sess.graph.get_operation_by_name("op_softmax").outputs[0]


这里需要注意的是,在最后面要加上“.outputs[0]”,否则会出现异常。

如果要直接运行sess.run(op_softmax),需要指定feed_dict。以官方mnist训练案例为例,调用格式为sess.run(op_softmax,feed_dict={x: mnist.test.images})。


在加载复用阶段,W和b的值已经保存在checkpoint数据中,故不需要再次声明W和b。但是,需要通过get_tensor_by_name()获取到x的声明,如下所示:

x = sess.graph.get_tensor_by_name("x:0")


上述op保存和加载操作可总结为:

1. 在训练代码中,为op指定名字;

2. 在复用阶段,通过get_operation_by_name().outputs[0]获取op;

3. 通过get_tensor_by_name()获取到feed_dict输入tensor(即x)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值