Tensorflow模型变量的保存

本文深入讲解了tf.train.Saver的使用方法,包括如何通过字典或列表指定变量,自动为checkpoint文件编码,以及如何利用计数器保存不同训练迭代轮的checkpoints文件。

1、tf.train.Saver(var_list=None, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0) 。该方法用于保存恢复变量。该类创建的对象可以“向checkpoints文件中保存变量”或“从checkpoints文件中恢复变量”。checkpoints文件为二进制编码文件。内容为变量名和tensor组成的键值对。

默认保存整个graph(var_list=None)。var_list指定了要保存恢复的变量,其形式可以为字典(dict)或列表(list)。

若var_list为dict,则在checkpoint文件中,键将被用于保存或恢复变量。若var_list为list形式,则变量的操作名(op.name)将被用作键,变量作为值,构成字典来保存恢复

 

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

Saver可以用计数器自动为checkpoint文件进行编码。这使你在训练模型时可以在不同的训练迭代轮(steps)保存不同编码的checkpoints文件。例如你可以使用训练迭代轮数来为checkpoint文件进行编码。默认参数max_to_keep=5表示只保存最近的5份checkpoints文件,默认参数keep_checkpoint_every_n_hours=10000.0表示每隔10000小时保存一个checkpoint文件,这样可以避免文件占满桌面。例

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

 saver.save方法会返回文件的保存路径。

 
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符  | 博主筛选后可见
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值