上一篇博文《TensorFlow模型参数的保存和加载》介绍了如何保存和加载TensorFlow模型训练参数,保存对象主要是Tensor/Variables。这一节我们介绍如何保存和复用op。
和Tensor一样,保存op需要在训练时为op指定名字,如下所示:
softmax = tf.nn.softmax(tf.matmul(x, W) + b,name="op_softmax")
在识别阶段,调用get_operation_by_name()函数,以op名字作为参数,如下所示:
op_softmax =sess.graph.get_operation_by_name("op_softmax").outputs[0]
这里需要注意的是,在最后面要加上“.outputs[0]”,否则会出现异常。
如果要直接运行sess.run(op_softmax),需要指定feed_dict。以官方mnist训练案例为例,调用格式为sess.run(op_softmax,feed_dict={x: mnist.test.images})。
在加载复用阶段,W和b的值已经保存在checkpoint数据中,故不需要再次声明W和b。但是,需要通过get_tensor_by_name()获取到x的声明,如下所示:
x = sess.graph.get_tensor_by_name("x:0")
上述op保存和加载操作可总结为:
1. 在训练代码中,为op指定名字;
2. 在复用阶段,通过get_operation_by_name().outputs[0]获取op;
3. 通过get_tensor_by_name()获取到feed_dict输入tensor(即x)