使用tf.trian.NewCheckpointReader(model_dir)
一个标准的模型文件有以下文件, model_dir就是MyModel(没有后缀)
checkpoint
Model.meta
Model.data-00000-of-00001
Model.index
import tensorflow as tf
import pprint # 使用pprint 提高打印的可读性
NewCheck =tf.train.NewCheckpointReader("model")
打印模型中的所有变量
print("debug_string:\n")
pprint.pprint(NewCheck.debug_string().decode("utf-8"))

其中有3个字段, 分别是名字, 数据类型, shape
获取变量中的值
print("get_tensor:\n")
pprint.pprint(NewCheck.get_tensor("D/conv2d/bias"))

在这里插入图片描述
print("get_variable_to_dtype_map\n")
pprint.pprint(NewCheck.get_variable_to_dtype_map())
print("get_variable_to_shape_map\n")
pprint.pprint(NewCheck.get_variable_to_shape_map())


原文:https://blog.youkuaiyun.com/qq_39124762/article/details/82951818
本文介绍如何使用TensorFlow的NewCheckpointReader类来读取模型中的变量,包括变量名称、数据类型、形状及具体值。通过实例演示了如何获取模型的详细信息,适用于模型调试和参数检查。
7286

被折叠的 条评论
为什么被折叠?



