第二阶段-tensorflow程序图文详解(五) Saving and Restoring

本文详细介绍了如何在TensorFlow中保存和恢复变量及模型。tf.train.Saver类用于管理并保存或恢复模型,而SavedModel是推荐的用于保存和加载模型的序列化格式,支持跨语言和恢复整个模型。文章涵盖了如何创建Saver对象,保存和恢复变量,以及如何选择要保存和恢复的变量。同时,文章还概述了SavedModel的概念、API用法以及CLI工具的使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

This document explains how to save and restore variables and models.
这篇blog讨论变量,模型的存储和加载。

1,Saving and restoring variables

A TensorFlow variable provides the best way to represent shared, persistent state manipulated by your program. (See Variables for details.) This section explains how to save and restore variables. Note that Estimators automatically saves and restores variables (in the model_dir).
variables提供最好的共享,持久化状态。注意Estimators自动化保存,恢复变量。

The tf.train.Saver class provides methods for saving and restoring models. The tf.train.Saver constructor adds save and restore ops to the graph for all, or a specified list, of the variables in the graph. The Saver object provides methods to run these ops, specifying paths for the checkpoint files to write to or read from.
tf.train.Saver类提供了保存和恢复模型的方法。tf.train.Saver构造器对于graph,或者graph中的list,variables。保存对象提供这些操作,指定checkpoint的路径,去读写。

The saver will restore all variables already defined in your model. If you’re loading a model without knowing how to build its graph (for example, if you’re writing a generic program to load models), then read the Overview of saving and restoring models section later in this document.

TensorFlow saves variables in binary checkpoint files that, roughly speaking, map variable names to tensor values.
TensorFlow将变量保存在二进制检查点文件中,粗略地说,它将变量名称映射为张量值。

1.1 Saving variables

Create a Saver with tf.train.Saver() to manage all variables in the model. For example, the following snippet demonstrates how to call the tf.train.Saver.save method to save variables to a checkpoint file:
创建一个Saver去管理所有的variables。下面代码片段,演示保存variables为checkpoint文件。

#创建一些variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)


# 这个操作用来初始化所有variables
init_op = tf.global_variables_initializer()

# 这个操作用来保存所有variables
saver = tf.train.Saver()

# 之后,运行这个模型,并保存variables到硬盘上。
with tf.Session() as sess:
  sess.run(init_op)
  #让这个模型工作起来
  inc_v1.op.run()
  dec_v2.op.run()
  # 讲variables保存到硬盘
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

1.2,Restoring variables

The tf.train.Saver object not only saves variables to checkpoint files, it also restores variables. Note that when you restore variables from a file you do not have to initialize them beforehand. For example, the following snippet demonstrates how to call the tf.train.Saver.restore method to restore variables from a checkpoint file:
恢复模型示例

tf.reset_default_graph()

# 创建一些 variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# 添加一个saver
saver = tf.train.Saver()

# 之后,开始保存模型
with tf.Session() as sess:
  # 从硬盘中加载ckpt文件
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # 将variables值打印出。
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

1.3,Choosing which variables to save and restore

If you do not pass any arguments to tf.train.Saver(), the saver handles all variables in the graph. Each variable is saved under the name that was passed when the variable was created.
如果不传递任何参数给tf.train.Saver(),保存器将处理图中的所有变量。 每个变量都保存在创建变量时传递的名称下。

It is sometimes useful to explicitly specify names for variables in the checkpoint files. For example, you may have trained a model with a variable named “weights” whose value you want to restore into a variable named “params”.
显式指定检查点文件中变量的名称有时很有用。 例如,您可能已经训练了一个名为“weights”的变量,该变量的值要恢复到名为“params”的变量中。

It is also sometimes useful to only save or restore a subset of the variables used by a model. For example, you may have trained a neural net with five layers, and you now want to train a new model with six layers that reuses the existing weights of the five trained layers. You can use the saver to restore the weights of just the first five layers.
保存或恢复模型使用的变量的子集,有时也是有用的。 例如,您可能已经训练了一个五层的神经网络,现在您要训练一个六层的新模型,重新使用五个训练层的现有权重。 您可以使用保存程序恢复前五层的权重。

You can easily specify the names and variables to save or load by passing to the tf.train.Saver() constructor either of the following:

  1. A list of variables (which will be stored under their own names).
  2. A Python dictionary in which keys are the names to use and the
    values are the variables to manage.
    一个Python字典,其中键是要使用的名称和
    值是要管理的变量。
    Continuing from the save/restore examples shown earlier:
tf.reset_default_graph()
# 创建一些 variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# 添加一个存储,恢复的操作,将v2,映射到key为v2
saver = tf.train.Saver({
  
  "v2": v2})

# 使用saver对象
with tf.Session() as sess:
  # 初始化v1变量,当不会保存。
  v1.initializer.run()
  saver.restore(sess, "/tmp/model.ckpt")

  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Notes:

  1. You can create as many Saver objects as you want if you need to save
    and restore different subsets of the model variables. The same
    variable can be listed in multiple saver objects; its value is only
    changed when the Saver.restore() method is run.
    如果需要保存和恢复模型变量的不同子集,可以根据需要创建任意多个Saver对象。 同一个变量可以列在多个保存对象中; 它的值只有在Saver.restore()方法运行时才会改变。

  2. If you only restore a subset of the model variables at the start of
    a session, you have to run an initialize op for the other variables.
    See tf.variables_initializer for more information.
    如果只在会话开始时恢复模型变量的子集,则必须为其他变量运行初始化操作。 有关更多信息,请参阅tf.variables_initializer。

  3. To inspect the variables in a checkpoint, you can use the
    inspect_checkpoint library, particularly the
    print_tensors_in_checkpoint_file function.

    要检查检查点中的变量,可以使用inspect_checkpoint库,特别是print_tensors_in_checkpoint_file函数。

  4. By default, Saver uses the value of the tf.Variable.name property
    for each variable. However, when you create a Saver object, you may
    optionally choose names for the variables in the checkpoint files.
    默认情况下,Saver使用每个变量的tf.Variable.name属性的值。 但是,当您创建一个Saver对象时,您可以选择为检查点文件中的变量选择名称。

2,Overview of saving and restoring models


When you want to save and load variables, the graph, and the graph’s metadata–basically, when you want to save or restore your model–we recommend using SavedModel. SavedModel is a language-neutral, recoverable, hermetic serialization format. SavedModel

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值