TensorFlow 学习初步- 变量,模型的存储和读取

本文介绍如何在TensorFlow中使用High-level API和Low-level API进行模型的检查点保存及变量的存储与恢复。包括配置检查点的时间频率、数量、路径等;使用tf.train.Saver保存模型中的所有或部分变量;以及如何从磁盘恢复变量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. High level API checkpoints

只针对与 estimator

设置检查点的时间频率和总个数

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

实例化时传递给 estimator 的 config 参数

model_dir 设置存储路径

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

一旦检查点文件存在,TensorFlow 总会在你调用 train() 、 evaluation() 或 predict() 时重建模型

------------------------------------------------------------------------------------------------------------

2.Low level API tf.train.Saver

-------------------------------------------------------------------------------------------------------------

Saver.save 存储 model 中的所有变量

import tensorflow as tf

# 创建变量
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)


# 添加初始化变量的操作
init_op = tf.global_variables_initializer()

# 添加保存和恢复这些变量的操作
saver = tf.train.Saver()

# 然后,加载模型,初始化变量,完成一些工作,并保存这些变量到磁盘中
with tf.Session() as sess:
  sess.run(init_op)
  # 使用模型完成一些工作
  var.op.run()

  # 将变量保存到磁盘中
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)

# tf.get_variable: Gets an existing variable with these parameters or create a new one.
# shape: Shape of the new or existing variable
# initializer: Initializer for the variable if one is created. tf.zeros_initializer 赋值为0 [0 0 0]
saver = tf.train.Saver() # Saver 来管理模型中的所有变量,注意是所有变量
tf.Session() # A class for running TensorFlow operations.
with...as...
#执行 with 后面的语句,如果可以执行则将赋值给 as 后的语句。如果出现错误则执行 with 后语句中的 __exit__
#来报错。类似与 try if,但是更方便

Saver.save 选择性的存储变量

saver = tf.train.Saver({'var2':var2})

-------------------------------------------------------------------------------------------------------------

Saver.restore 加载路径中的所有变量

import tensorflow as tf

tf.reset_default_graph()

# 创建一些变量
var = tf.get_variable("var", shape=[3])

# 添加保存和恢复这些变量的操作
saver = tf.train.Saver()

# 然后,加载模型,使用 saver 从磁盘中恢复变量,并使用变量完成一些工作
with tf.Session() as sess:
  # 从磁盘中恢复变量
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # 检查变量的值
  print("var : %s" % var.eval())

-------------------------------------------------------------------------------------------------------------

inspector_checkpoint 检查存储的变量

加载 inspect_checkpoints

from tensorflow.python.tools import inspect_checkpoint as chkp

打印存储起来的所有变量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)

注意其中的参数 all_tensor_names 教程中并未添加这个参数,运行时持续报错 missing

打印制定的变量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)

 

转载于:https://my.oschina.net/u/2362565/blog/1802226

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值