TensorFlow保存和恢复变量——tf.train.Saver()

本文介绍了如何使用TensorFlow的tf.train.Saver()类来保存和恢复模型变量,包括初始化Saver、保存和恢复变量、选择要保存的变量以及检查ckpt文件中的变量内容。示例代码展示了变量保存的具体步骤,以及如何从checkpoint文件中恢复变量值。

声明:

  1. 参考Tensorflow官方文档
  2. tensorflow当前版本1.1
  3. 更新:现在tensorflow官网有了中文教程,很方便学习了

tf.train.Saver()

tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。

TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。在tf.train.Saver()类初始化时,用于保存和恢复的saverestore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。

saver = tf.train.Saver()

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型

保存变量

TensorFlow会讲变量保存在二进制checkpoint文件中,这类文件会将变量名称映射到张量值。

下面是保存变量的例子:

  1. 创建变量
  2. 初始化变量
  3. 实例化tf.train.Saver()
  4. 创建Session并保存
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1_name", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2_name", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "model/model.ckpt")  # 返回一个保存路径的字符串
    print("Model saved in path: %s" % save_path)
    
'''output
Model saved in path: model/model.ckpt
'''

保存路径中的文件为:

  • checkpoint:保存当前网络状态的文件
  • model.data-00000-of-00001
  • model.index
  • model.meta:保存Graph结构的文件

可以发现,没有名为 ‘model/model.ckpt’ 的实体文件,其中 ‘model’ 是一个与用户交互的前缀。

关于函数saver.save(),常用的参数就是前三个:

save(
	sess,  # 必需参数,Session对象
	save_path,  # 必需参数,存储路径
	global_step=None,  # 可以是Tensor, Tensor name, 整型数
	latest_filename=None,  # 协议缓冲文件名,默认为'checkpoint',不用管
	meta_graph_suffix='meta',  # 图文件的后缀,默认为'.meta',不用管
	write_meta_graph=True,  # 是否保存Graph
	write_state=True,  # 建议选择默认值True
	strip_default_attrs=False  # 是否跳过具有默认值的节点

恢复变量

从checkpoint文件中提取变量值赋给新定义的变量。

tf.reset_default_graph()
# Create some variables
# !!!variable name必须与保存时的name一致
v1 = tf.get_variable("v1_name", shape=[3])
v2 = tf.get_variable("v2_name", shape=[5])

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore variables from disk
    saver.restore(sess, "model/model.ckpt")
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 1.  1.  1.]
v2: [-1. -1. -1. -1. -1.]
'''

variable().eval()
eval(session=None)
In a session computes and returns the value of this variable.

选择要保存和恢复的变量

tf.train.Saver()的构造函数传递以下任意内容来轻松指定要保存或加载的名称和变量:

  • 变量列表(要求变量与变量名之间的一一对应)
  • Python字典,其中,key是要使用的名称,value是要管理的变量(通过键值映射自定义变量与变量名之间的对应关系)
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2_name": v2})
with tf.Session() as sess:
    v1.initializer.run()
    saver.restore(sess, "model/model.ckpt")
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 0.  0.  0.]
v2: [-1. -1. -1. -1. -1.]
'''

variable().initializer
The initializer operation for this variable.

Note

  • 如果需要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的Saver对象。同一个变量可以列在多个Saver对象中。
  • 变量的值只有在Saver.restore()方法运行时才会更改,这些变量不需要初始化。
  • 如果您仅在会话开始时恢复模型变量的子集,则必须为其他变量运行初始化 op。

查看ckpt二进制文件中的变量

我们可以使用 inspect_checkpoint 库快速检查某个检查点的变量。

prints tensors in a checkpoint file.

If no tensor_name is provided, prints the tensor names and shapes in
the checkpoint file.

If tensor_name is provided, prints the content of the tensor.

ArgsDescription
file_nameName of the checkpoint file
tensor_nameName of the tensor in the checkpoint file
all_tensorsBoolean indicating whether to print all tensors
all_tensor_namesBoolean indicating whether to print all tensor names
from tensorflow.python.tools import inspect_checkpoint as ickpt
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=False)
'''output
tensor_name:  v1_name
[ 1.  1.  1.]
'''
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=True)
'''output
tensor_name:  v1_name
[ 1.  1.  1.]
tensor_name:  v2_name
[-1. -1. -1. -1. -1.]
'''
### 回答1: 你好,我理解你是想通过 优快云 学习 C 知识,并且你也使用 ChatGPT 这种类型的 AI。请问你有什么问题需要我回答吗? 关于你提到的问题,我可以解释一下。tf.train.Saver()TensorFlow 中用于保存恢复模型的类。它允许你将 TensorFlow 变量保存到磁盘,并从这些变量恢复。通常在训练过程中,你需要定期保存模型以便稍后恢复或进行推断。tf.train.Saver() 的常见用法如下: ```python import tensorflow as tf # 声明 TensorFlow 变量 weights = tf.Variable(...) biases = tf.Variable(...) # 初始化 TensorFlow 变量 init_op = tf.global_variables_initializer() # 训练模型 # 创建 Saver 对象 saver = tf.train.Saver() # 保存模型 with tf.Session() as sess: sess.run(init_op) # 训练模型 saver.save(sess, 'my_model') # 保存模型到文件 'my_model' ``` 当你需要恢复模型时,可以使用 tf.train.Saver() 的 restore() 方法,如下所示: ```python # 创建 Saver 对象 saver = tf.train.Saver() # 加载模型 with tf.Session() as sess: saver.restore(sess, 'my_model') # 从文件 'my_model' 恢复模型 # 执行推断 ``` 希望这可以回答你的问题。如果你有其他问题,请随时提出。 ### 回答2: tf.train.Saver()是一个用于保存、加载以及管理模型参数的类。在TensorFlow中,我们通常使用tf.train.Saver()保存恢复模型的变量。 在TensorFlow中,模型的训练过程通常会生成一些变量,比如神经网络的权重偏差。而tf.train.Saver()类提供了一种方法,可以将这些变量保存到文件中。通过调用tf.train.Saver().save()方法,可以将模型的变量保存在一个checkpoint文件中,以供将来使用。 除了保存模型变量tf.train.Saver()还可以用于加载已保存的模型变量。通过调用tf.train.Saver().restore()方法,可以从checkpoint文件中载入模型的变量,并且将其赋值给指定的TensorFlow变量。这样,我们就可以在程序中使用这些已保存的模型变量,而无需重新训练模型。 另外,tf.train.Saver()还具备一些其他的功能,比如可以指定保存加载的变量以及保存恢复模型的过程是否应该包含模型的图结构。 总结起来,tf.train.Saver()是一个用于保存、加载管理TensorFlow模型参数的类。它提供了保存恢复模型变量的功能,可以确保模型的训练结果可以方便地在之后的使用中进行加载重用。 ### 回答3: tf.train.Saver()tensorflow中用于模型参数的保存恢复的类。 在tensorflow中,模型参数通常是在训练过程中不断更新的,而为了保留训练过程中的模型参数,我们可以使用tf.train.Saver()类来保存这些参数。tf.train.Saver()类提供了保存恢复模型的方法,可以将模型的参数保存到文件中,并在需要的时候恢复这些参数。 保存模型参数是通过调用tf.train.Saver()类的save()方法实现的。save()方法需要传入一个session一个保存路径,表示将当前模型的参数保存到指定的路径下。保存的参数可以是全局变量、权重、偏置等等。 恢复模型参数是通过调用tf.train.Saver()类的restore()方法实现的。restore()方法需要传入一个session一个保存路径,表示从指定的路径中恢复模型的参数。恢复参数时,tensorflow会自动判断模型的参数是否与当前模型的参数匹配,如果匹配,则恢复参数;如果不匹配,则会抛出异常。 使用tf.train.Saver()类可以实现模型的断点续训。即在训练过程中,可以将当前的模型参数保存到文件中。如果训练过程中发生意外,可以在恢复训练时,加载之前保存的模型参数,从上一次中断的地方继续训练。 总之,tf.train.Saver()tensorflow中用于保存恢复模型参数的重要工具,它提供了方便的接口,使得我们可以灵活地管理模型参数,实现模型的保存恢复断点续训。
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值