更新:
2018.5.4 补充模型保存和恢复的原理,补充了模型保存和恢复的一般流程
版本:tensorflow 1.8
前面一直说的都是没有涉及到模型的保存.一般深度学习的训练是很需要时间的,不可能程序退出了然后又重新训练一次,所以训练好的模型需要保存下来,方便之后的再训练或者是把模型分享给别人都是可以的.模型的保存也可以叫做持久化,一个意思.接下来不啰嗦了,用一个简单的例子来说说模型怎么保存.
本节的所有代码可以在我的GitHub找到:
一.常见类和函数
在这部分先把模型保存和恢复中的常见类和函数列出来,可以暂时先不用详细看他们是怎么用的,这里先混个眼熟。
Ⅰ.tf.train.Saver
保存模型最基本的类就是这个类啦,所以一旦涉及到保存模型的需求,这个类是不可避免的。Saver
类提供了很多方便的操作能够从checkpoints(检查点)保存和恢复变量,
Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.
Saver类能够通过给定的计数器自动为checkpoint文件编号,这能够在训练模型的不同阶段保存多个检查点文件,举个栗子,你能够使用训练轮数来为你的检查点文件编号,为了避免占满硬盘,Saver还能够自动管理检查点文件,比如保留最新的那N个文件等等。
Saver类提供了一些函数来进行模型的保存和恢复,这里按照平时使用的频率来排序,列出常见的类方法。
save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix=’meta’,write_meta_graph=True,write_state=True,strip_default_attrs=False)
作用就是保存变量,
参数:
sess: 运行当前图的session
save_path: 检查点(checkpoint)文件名
global_step: If provided the global step number is appended to save_path to create the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
latest_filename: 可选名称,表示最新的检查点(checkpoints),默认为’checkpoint’.
meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to ‘meta’.
write_meta_graph: Boolean indicating whether or not to write the meta graph file.
write_state: Boolean indicating whether or not to write the CheckpointStateProto.
strip_default_attrs: Boolean. If True, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see Stripping Default-Valued Attributes.
返回:
检查点文件保存的路径。 If the saver is sharded, this string ends with: ‘-?????-of-nnnnn’ where ‘nnnnn’ is the number of shards created. If the saver is empty, returns None.
restore(sess,save_path)
恢复变量,同时要求图运行在这个session里面。
参数:
sess: 用来恢复参数的session
save_path: 保存模型的地址,一般来说,常常使用save()
函数返回的地址后者使用latest_checkpoint()
来得到地址。
Ⅱ.tf.train.latest_checkpoint
tf.train.latest_checkpoint(checkpoint_dir,latest_filename=None)
找到最新保存的checkpoint文件的文件名
参数:
checkpoint_dir: 变量保存的目录
latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames. See the corresponding argument to