tf里面提供模型保存的是tf.train.Saver()模块。模型保存,先要创建一个Saver对象:如
saver=tf.train.Saver()
max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。
如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可。
Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
第一个参数sess。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。
isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
train_steps:表示训练的次数,例子中使用100
checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径
使用Saver.save()方法保存模型:
- sess:表示当前会话,当前会话记录了当前的变量值
- checkpoint_dir + 'model.ckpt':表示存储的文件名
- global_step:表示当前是第几步