参考:
https://blog.youkuaiyun.com/liuxiaodong400/article/details/83421164
1. 实例化对象
saver=tf.train.Saver(max_to_keep=1)
class Saver(object):
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
max_to_keep:表明保存的最大checkpoint文件数,当一个新文件创建的时候,旧文件就会被删除掉,如果值为none或,表示保存所有的checkpoint文件,默认值为5,也就是保存最近的5个checkpoint文件.
keep_checkpoint_every_n_hours: 除了保存最近的checkpoint文件,每隔N小时保存一个checkpoint文件.
2.保存训练过程中的或者训练好的,模型图及权重系数
2.1 前面创建好saver对象后,就可以保存训练好的模型了
saver.save(sess=sess,save_path=model_save_path,global_step=step)
第一个参数sess=sess,会话名字
第二个参数,save_path=model_save_path,权重参数保存的路径和文件名
第三个参数,global_step=step,将训练的的次数作为后缀加入到模型名字中
2.2,一次saver.save()后可以在文件夹中看到新增的4个文件
每调用一次保存操作会创建后3个文件,并创建一个检查点checkpoint文件
简单理解就是权重等参数被保存到.ckpt.data文件中,以字典的形式
.ckpt.index,应该是内部需要的某种索引来正确映射前两个文件
.ckpt.meta,文件保存图和元数据,可以使用tf.train.import_meta_graph加载
3. 重载模型参数,继续训练或用于测试
saver.restore(sess=sess, save_path=model_save_path)
第一个参数sess=sess,会话名字
第二个参数,save_path=model_save_path,权重参数保存的路径和文件名