【Tensorflow基础】第九课:模型的保存和载入

1.tf.train.Saver()

tf.train.Saver()用于保存和加载模型。

saver=tf.train.Saver()

tf.train.Saver()参数见下:

def __init__(
	self,
	var_list=None,
	reshape=False,
	sharded=False,
	max_to_keep=5,
	keep_checkpoint_every_n_hours=10000.0,
	name=None,
	restore_sequentially=False,
	saver_def=None,
	builder=None,
	defer_build=False,
	allow_empty=False,
	write_version=saver_pb2.SaverDef.V2,
	pad_step_number=False,
	save_relative_paths=False,
	filename=None)

部分参数解释:

  1. var_list:指定要保存和恢复的变量。
  2. max_to_keep:是经常会用到的一个参数。用于设置保存模型的个数(默认为max_to_keep=5,即保存最近的5个模型)。若max_to_keep设置为None或0,则保存所有的模型。
  3. keep_checkpoint_every_n_hours:每n个小时保存一次模型。

1.1.saver.save()

def save(
	self,
	sess,
	save_path,
	global_step=None,
	latest_filename=None,
	meta_graph_suffix="meta",
	write_meta_graph=True,
	write_state=True,
	strip_default_attrs=False,
	save_debug_info=False)

部分参数解释:

  1. sess:Session。
  2. save_path:模型保存路径。例如:saver.save(sess, 'net/my_net.ckpt')
  3. global_step:用来给模型文件名添加数字标记。例如:saver.save(sess, 'my-model', global_step=0),保存得到的模型文件名为:'my-model-0'

1.2.saver.restore()

def restore(self, sess, save_path)

参数解释:

  1. sess:Session。
  2. save_path:模型路径。例如:saver.restore(sess, 'net/my_net.ckpt')

导入模型之前,必须重新再定义一遍变量。但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint(checkpoint_dir, latest_filename=None)
saver.restore(sess,model_file)

2.ckpt模型

使用saver.save()将模型保存为ckpt格式,会生成以下四个文件:

  1. my_net.ckpt.meta:保存了Tensorflow计算图的结构,即网络结构。
  2. my_net.ckpt.indexmy_net.ckpt.data-00000-of-00001:保存了所有变量的取值。
  3. checkpoint:保存了一个目录下所有的模型文件列表。

3.代码地址

  1. 模型的保存和载入

想要获取最新文章推送或者私聊谈人生,请关注我的个人微信公众号:⬇️x-jeff的AI工坊⬇️

个人博客网站:https://shichaoxin.com

GitHub:https://github.com/x-jeff


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值