输出TensorFlow中checkpoint内变量的几种方法

将TF保存在checkpoint中的变量值输出到txt或npy中的几种简单的可行的方法:

1,最简单的方法,是在有model 的情况下,直接用tf.train.saver进行restore,就像 cifar10_eval.py 中那样。然后,在sess中直接run变量的名字就可以得到变量保存的值。


在这里以cifar10_eval.py为例。首先,在Graph中穿件model。

    with tf.Graph().as_default() as g:  
        images, labels = cifar10.inputs(eval_data=eval_data)  
        logits = cifar10.inference(images)  
        top_k_op = tf.nn.in_top_k(logits, labels, 1)  

然后,通过tf.train.ExponentialMovingAverage.variable_to_restore确定需要restore的变量,默认情况下是model中所有trainable变量的movingaverge名字。并建立saver 对象

 variable_averages = tf.train.ExponentialMovingAverage(  
            cifar10.MOVING_AVERAGE_DECAY)  
        variables_to_restore = variable_averages.variables_to_restore()  
        saver = tf.train.Saver(variables_to_restore)  


variables_to_restore中是变量的movingaverage名字到变量的mapping(就是个字典)。我们可以打印尝试打印里面的变量名,

    for name in variables_to_restore:  
        print(name)  



输出结果为

    softmax_linear/biases/ExponentialMovingAverage



    conv2/biases/ExponentialMovingAverage



    local4/biases/ExponentialMovingAverage



    local3/biases/ExponentialMovingAverage



    softmax_linear/weights/ExponentialMovingAverage



    conv1/biases/ExponentialMovingAverage



    local4/weights/ExponentialMovingAverage



    local3/weights/ExponentialMovingAverage



    conv2/weights/ExponentialMovingAverage



    conv1/weights/ExponentialMovingAverage



然后在中通过run 变量名的方式就可以得到保存在checkpoint中的值,引文sess.run方法得到的是numpy形式的数据,就可以通过np.save或np.savetxt来保存了。


    with tf.Session() as sess:  
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)  
      if ckpt and ckpt.model_checkpoint_path:  
        # Restores from checkpoint  
        saver.restore(sess, ckpt.model_checkpoint_path)  
        conv1_w=sess.run('conv1/weights/ExponentialMovingAverage')  
 
此时conv1_w就是conv1/weights的MovingAverage的值,并且是numpy array的形式。




2, 第二种方法是使用
tensorflow/python/tools/inspect_checkpoint.py
 中提到的tf.train.NewCheckpointReader类

这种方法不需要model,只要有checkpoint文件就行。
首先用tf.train.NewCheckpointReader读取checkpoint文件
如果没有指定需要输出的变量,则全部输出,如果指定了,则可以输出相应的变量
 
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("file_name", "", "Checkpoint filename")
tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")


def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.
  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.
  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    reader = tf.train.NewCheckpointReader(file_name)
    if not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")


def main(unused_argv):
  if not FLAGS.file_name:
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]")
    sys.exit(1)
  else:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)

if __name__ == "__main__":
tf.app.run()


3,第三种方法也是TF官方在tool里面给的,称为freeze_graph, 在官方的这个tutorials中有介绍。
一般情况下TF在训练过程中会保存两种文件,一种是保存了变量值的checkpoint文件,另一种是保存了模型的Graph(GraphDef)等其他信息的MetaDef文件,
以.meta结尾Meta,但是其中没有保存变量的值。freeze_graph.py的主要功能就是将chenkpoint中的变量值保存到模型的GraphDef中,使得在一个文件中既
包含了模型的Graph,又有各个变量的值,便于后续操作。当然变量值的保存是可以有选择性的。
在freeze_graph.py中,首先是导入GraphDef (如果有GraphDef则可之间导入,如果没有,则可以从MetaDef中导入). 然后是从GraphDef中的所有nodes中
抽取主模型的nodes(比如各个变量,激活层等)。再用saver从checkpoint中恢复变量的值,以constant的形式保存到抽取的Grap的nodes中,并输出此GraphDef.
GraphDef 和MetaDef都是 基于Google Protocol Buffer 定义的。在GraphDef 中主要以node(NodeDef) 来保存模型。
### TensorFlow 中保存模型的方法及示例 在 TensorFlow 中,保存模型是一个重要的操作,它允许我们将训练好的模型持久化以便后续加载和使用。以下是几种常见的保存模型的方式及其对应的代码实现。 #### 方法一:通过 `tf.train.Saver` 保存 Checkpoint 文件 这种方法适用于 TensorFlow 的早期版本(如 TensorFlow 1.x),也可以用于 TensorFlow 2.x 的兼容模式下。可以通过创建一个 `Saver` 对象来保存模型的状态到指定路径[^3]。 ```python import tensorflow as tf # 定义变量并初始化 v1 = tf.Variable(tf.random.normal([784, 200], stddev=0.35), name="v1") v2 = tf.Variable(tf.zeros([200]), name="v2") # 创建 Saver 对象 saver = tf.compat.v1.train.Saver() with tf.compat.v1.Session() as sess: # 初始化变量 sess.run(tf.compat.v1.global_variables_initializer()) # 执行一些计算... # 将当前会话状态保存至 checkpoint 文件 save_path = saver.save(sess, './model/model.ckpt') print(f'Model saved in path: {save_path}') ``` #### 方法二:Keras API 下的 `.save()` 和 `.load_model()` 对于基于 Keras 构建的模型,在 TensorFlow 2.x 中可以更方便地使用内置函数保存整个模型结构以及权重数据[^2]。 ```python from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense # 建立简单的神经网络模型 model = Sequential([ Dense(64, activation='relu', input_shape=(784,)), Dense(10, activation='softmax') ]) # 编译模型 model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) # 训练模型 (此处省略实际的数据输入部分) # 使用 .save 函数存储完整的模型信息 model.save('my_keras_model.h5') print("Model has been successfully saved.") ``` #### 方法三:SavedModel 格式导出模型 这是推荐的一种方式,尤其适合于生产环境下的服务部署需求。它可以将模型转换成独立运行的形式,支持多种框架间的互操作性[^5]。 ```python import tensorflow as tf # 加载已有的 keras 模型或者定义新的模型实例 loaded_model = ... export_dir = 'saved_models/my_model' tf.saved_model.save(loaded_model, export_dir) print(f'Saved model written to: {export_dir}.') ``` 以上三种方法分别展示了不同场景下如何有效地完成 TensorFlow 模型的保存工作。每一种都有其特定的应用场合,请根据具体项目的需求选择合适的技术方案[^1].
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值