TensorFlow学习(十二):模型的保存与恢复(上)基本操作

本文介绍了TensorFlow中模型的保存与恢复,包括tf.train.Saver类、tf.train.latest_checkpoint函数等关键方法的使用。通过示例展示了如何保存和恢复模型,以及模型文件的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

更新:

2018.5.4 补充模型保存和恢复的原理,补充了模型保存和恢复的一般流程
版本:tensorflow 1.8

前面一直说的都是没有涉及到模型的保存.一般深度学习的训练是很需要时间的,不可能程序退出了然后又重新训练一次,所以训练好的模型需要保存下来,方便之后的再训练或者是把模型分享给别人都是可以的.模型的保存也可以叫做持久化,一个意思.接下来不啰嗦了,用一个简单的例子来说说模型怎么保存.
本节的所有代码可以在我的GitHub找到:

一.常见类和函数

在这部分先把模型保存和恢复中的常见类和函数列出来,可以暂时先不用详细看他们是怎么用的,这里先混个眼熟。

Ⅰ.tf.train.Saver

保存模型最基本的类就是这个类啦,所以一旦涉及到保存模型的需求,这个类是不可避免的。Saver类提供了很多方便的操作能够从checkpoints(检查点)保存和恢复变量,
Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.
Saver类能够通过给定的计数器自动为checkpoint文件编号,这能够在训练模型的不同阶段保存多个检查点文件,举个栗子,你能够使用训练轮数来为你的检查点文件编号,为了避免占满硬盘,Saver还能够自动管理检查点文件,比如保留最新的那N个文件等等。
Saver类提供了一些函数来进行模型的保存和恢复,这里按照平时使用的频率来排序,列出常见的类方法。

save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix=’meta’,write_meta_graph=True,write_state=True,strip_default_attrs=False)
作用就是保存变量,

参数:

sess: 运行当前图的session
save_path: 检查点(checkpoint)文件名
global_step: If provided the global step number is appended to save_path to create the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
latest_filename: 可选名称,表示最新的检查点(checkpoints),默认为’checkpoint’.
meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to ‘meta’.
write_meta_graph: Boolean indicating whether or not to write the meta graph file.
write_state: Boolean indicating whether or not to write the CheckpointStateProto.
strip_default_attrs: Boolean. If True, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see Stripping Default-Valued Attributes.

返回:
检查点文件保存的路径。 If the saver is sharded, this string ends with: ‘-?????-of-nnnnn’ where ‘nnnnn’ is the number of shards created. If the saver is empty, returns None.

restore(sess,save_path)
恢复变量,同时要求图运行在这个session里面。
参数:
sess: 用来恢复参数的session
save_path: 保存模型的地址,一般来说,常常使用save() 函数返回的地址后者使用latest_checkpoint() 来得到地址。

Ⅱ.tf.train.latest_checkpoint

tf.train.latest_checkpoint(checkpoint_dir,latest_filename=None)
找到最新保存的checkpoint文件的文件名

参数:
checkpoint_dir: 变量保存的目录
latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames. See the corresponding argument to

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值