1. High level API checkpoints
只针对与 estimator
设置检查点的时间频率和总个数
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes.
keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.
)
实例化时传递给 estimator 的 config 参数
model_dir 设置存储路径
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
一旦检查点文件存在,TensorFlow 总会在你调用 train()
、 evaluation()
或 predict()
时重建模型
------------------------------------------------------------------------------------------------------------
2.Low level API tf.train.Saver
-------------------------------------------------------------------------------------------------------------
Saver.save 存储 model 中的所有变量
import tensorflow as tf
# 创建变量
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)
# 添加初始化变量的操作
init_op = tf.global_variables_initializer()
# 添加保存和恢复这些变量的操作
saver = tf.train.Saver()
# 然后,加载模型,初始化变量,完成一些工作,并保存这些变量到磁盘中
with tf.Session() as sess:
sess.run(init_op)
# 使用模型完成一些工作
var.op.run()
# 将变量保存到磁盘中
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)
# tf.get_variable: Gets an existing variable with these parameters or create a new one.
# shape: Shape of the new or existing variable
# initializer: Initializer for the variable if one is created. tf.zeros_initializer 赋值为0 [0 0 0]
saver = tf.train.Saver() # Saver 来管理模型中的所有变量,注意是所有变量
tf.Session() # A class for running TensorFlow operations.
with...as...
#执行 with 后面的语句,如果可以执行则将赋值给 as 后的语句。如果出现错误则执行 with 后语句中的 __exit__
#来报错。类似与 try if,但是更方便
Saver.save 选择性的存储变量
saver = tf.train.Saver({'var2':var2})
-------------------------------------------------------------------------------------------------------------
Saver.restore 加载路径中的所有变量
import tensorflow as tf
tf.reset_default_graph()
# 创建一些变量
var = tf.get_variable("var", shape=[3])
# 添加保存和恢复这些变量的操作
saver = tf.train.Saver()
# 然后,加载模型,使用 saver 从磁盘中恢复变量,并使用变量完成一些工作
with tf.Session() as sess:
# 从磁盘中恢复变量
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# 检查变量的值
print("var : %s" % var.eval())
-------------------------------------------------------------------------------------------------------------
inspector_checkpoint 检查存储的变量
加载 inspect_checkpoints
from tensorflow.python.tools import inspect_checkpoint as chkp
打印存储起来的所有变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)
注意其中的参数 all_tensor_names 教程中并未添加这个参数,运行时持续报错 missing
打印制定的变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)