tf.train.Saver

tf.train.Saver是TensorFlow中用于保存和恢复模型变量的工具,支持自动编号和清理checkpoint文件。Saver可以通过全局步数对文件命名,并控制保留的checkpoint数量或间隔。它允许指定要保存的变量列表,并提供了保存和恢复模型的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Saver类用来保存和恢复变量

Saver类增加了保存和恢复变量到checkpoints的操作。它还提供了运行这些操作的便利方法。

Checkpoints是专有格式的二进制文件,将变量名称映射到张量值。检查Checkpoints文件内容的最佳方式是使用Saver加载它。

Saver可以用计数器自动编号checkpoint文件,这可以让你在训练模型时,在不同的步骤中保留多个checkpoint。例如,你可以使用训练步数对checkpoint文件名进行编号。为了避免填写磁盘,savers自动管理checkpoint文件。例如,只保留N个最近的checkpoint文件,或者每N个训练时间内保留一个checkpoint文件。
You number checkpoint filenames by passing a value to the optional global_step argument to save():
你可以通过向传值对
通过将值传递给 save()可选的global_step参数来为checkpoint文件编号:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
saver.save(sess, 'model', global_step=0) ==> filename: 'model-0'
...
saver.save(sess, 'model', global_step=1000) ==> filename: 'model-1000'

Additionally, optional arguments to the Saver() constructor let you control the proliferation of checkpoint files on disk
另外,Saver()构造函数的可选参数可以控制磁盘上checkpoint文件的扩散

  • max_to_keep: 要保留的最近checkpoint 文件的最大数量。 创建新文件时,会删除较旧的文件。如果是None或0,则保留所有checkpoint 文件。默认为5(即保留最新的5个检查点文件。)
  • keep_checkpoint_every_n_hours: 除了保留最新的max_to_keep检查点文件之外,你也可能需要在每N小时的训练中保留一个检查点文件。如果你想要在长时间的训练session中分析模型的进展情况,这将非常有用。例如,keep_checkpoint_every_n_hours = 2可确保每2个小时的训练中保留一个checkpoint文件。默认值为10,000小时,这有效地禁用了该功能。

请注意,你仍然必须调用save()方法来保存模型。将这些参数传递给构造函数不会自动保存变量。

定期保存的训练代码如下:

...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

还有一个’checkpoint’文件,存的是最近的几个检查点文件。

model_checkpoint_path: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-30000"
all_model_checkpoint_paths: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-29600"
all_model_checkpoint_paths: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-29700"
all_model_checkpoint_paths: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-29800"
all_model_checkpoint_paths: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-29900"
all_model_checkpoint_paths: "/dockerData/cnn/cnn-text-classification-tf/runs/1490200383/checkpoints/model-30000"

如果创建几个saver,可以在调用save()时为协议缓冲区文件指定不同的文件名

tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=2, pad_step_number=False) {:#Saver.init}

创建一个saver
构造函数添加了保存和恢复变量的操作。

var_list specifies the variables that will be saved and restored. It can be passed as a dict or a list:
var_list指定将被保存和还原的变量。它可以作为dict或者列表传递:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
# 如果不指名参数,默认的是所有的变量
saver = tf.train.Saver()

Args:
主要参数:

  • var_list: 一个变量列表,或者是词典,如果没有,则默认为所有可保存对象的列表。
  • reshape: 如果为True,则允许从具有不同形状的变量从检查点恢复参数。
  • sharded: 如果为True,则将检查点分片,每个设备一个。
  • max_to_keep: 最近保留的检查点的最大数量。默认为5。
  • keep_checkpoint_every_n_hours: 保持检查点多久。默认为10,000小时。
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)

保存变量

此方法运行由构造函数添加的用于保存变量的ops。它需要一个会话,其中启动了图形。要保存的变量也必须已初始化。

该方法返回新创建的检查点文件的路径。该路径可以直接传递给对restore()的调用。

参数:

  • sess: 保存变量的会话。
  • save_path: String. 检查点文件名的路径。如果保护程序被分片,则这是分片的检查点文件名的前缀。
  • global_step: 如果提供,则将全局步数附加到save_path以创建检查点文件名。可选参数可以是Tensor,Tensor名称或整数。
  • latest_filename: 将包含最新检查点文件名列表的协议缓冲区文件的可选名称。保存在与检查点文件相同的目录中的文件由保存程序自动管理以跟踪最近的检查点。默认为“checkpoint”。

返回:
一个字符串:保存变量的路径。如果保护程序被分片,则该字符串以:’-nnnnn’结尾,其中’nnnnn’是创建的分片数。如果保护程序为空,则返回None。

tf.train.Saver.restore(sess, save_path)

恢复保存的变量

此方法运行由构造函数添加的用于恢复变量的ops。它需要一个会话,其中启动了图形。要还原的变量不必被初始化,因为恢复本身就是初始化变量的一种方式。

save_path参数通常是先前从save()调用返回的值,或者是调用latest_checkpoint()。

例子:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess,'model',global_step=n)
    saver.restore(sess,'model')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值