tensorflow模型的保存与恢复
今天不想学习,缓一缓呀,上一章用到了tensorflow模型的保存,今天详细的再学一学。
大纲:
-
简单的保存与恢复
-
更为细化的保存与恢复方法
简单的保存与恢复
Saver是tensorflow中的保存类,其常用方法为save和restore 分别用于模型的保存和恢复。
一般的保存方式为:
save_path = "..." # 自行填入保存地址
saver = tf.train.Saver()
Sess = tf.Session()
...
# 模型定义
...
saver.save(sess, save_path)
结果会有如下几个文件:
其中,
checkpoint为检查点文件,记录存储文件的名称
.ckpt.index存储权重目录
.ckpt.meta存储整个数据流图
.ckpt.data存储所有权重
同时,也可以多次通过设置global_step多次保存模型:
for i in range(100):
saver.save(sess, save_path, global_step=i)
但是为了防止占据太多空间,tensorflow仅仅会保存最近的5个模型文件,删除其他模型。
模型的恢复也很简单:
save_path = "..."
saver = tf.train.Saver()
sess = tf.Session()
...
# 模型定义
...
saver.restore(sess, save_path)
更为细化的模型保存与恢复
当进行整个模型的存储和恢复时,上述方法已经够用,但需要对特定的权重进行恢复和处理时,首先需要为每个变量进行命名:
with tf.variable_scope("var"):
a = tf.Variable(tf.random_normal([1]), name="a_val")
x = tf.placeholder(tf.float32, name="x_placeholder")
print(a)
print(x)
其中==tf.variable_scope(“var”)==的作用是定义一个命名域
输出a和x:
<tf.Variable 'var/a_val:0' shape=(1,) dtype=float32_ref>
Tensor("x_placeholder:0", dtype=float32)
可以发现a的名字为== ‘var/a_val:0’==,var为其命名空间,0是为了进行参数的复用而使用的符号。
然后就可以恢复并处理特定的值啦:
saver = tf.train.import_meta_graph('./model/AlexNetModel.ckpt.meta')
graph = tf.get_default_graph()
a = graph.get_tensor_by_name('var/a_val:0')
x = graph.get_tensor_by_name('x_placeholder:0')
print(a)
print(x)
输出:
<tf.Variable 'var/a_val:0' shape=(1,) dtype=float32_ref>
Tensor("x_placeholder:0", dtype=float32)
为了得到其值,则需要在对话中完成:
saver = tf.train.import_meta_graph('./model/AlexNetModel.ckpt.meta')
graph = tf.get_default_graph()
a = graph.get_tensor_by_name('var/a_val:0')
with tf.Session() as sess:
saver.restore(sess, './model/AlexNetModel.ckpt') # 恢复权重
result = sess.run(a)
print(result)
今天划了一天,明天一定好好学习!
8说了,好久没打游戏了,晚上回去爽一爽嘻嘻。