原文链接:https://blog.youkuaiyun.com/liangyihuai/article/details/78515913
1、博文回顾:
这篇博客中对于模型加载、保存、加载后的更改、保存加载的文件做了一个比较全面的讲解,主要包括:
- tensorflow模型是什么?到底保存和加载了什么东西?
模型包括图文件(meta graph)和变量文件(ckpt),前者定义了模型的结构(图,节点等),后者保存了图中所有变量的具体值。 - 保存tensorflow模型方法。
...... #创建变量,计算路径等
saver = tf.train.Saver(var_list) #可以通过varlist指定需要保存的变量,默认全部保存。
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
...... #其他计算
saver.save(sess, 'my_test_model',global_step=step) #注意要在tf的Session里面创建保存对象,因为tf变量作用范围在session中;可以通过global_step来指定每经过step步保存一次。
- 加载保存的模型
with tf.Session() as sess:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint(ckpt_file_path))
graph = tf.get_defalut_graph()
restore_tensor = graph.get_tensor_by_name('name:0')#由加载的图和参数名获取对应的tensor
2、另外在看到一篇模型加载文章时,提到加载可以有三种方式
文章链接:https://blog.youkuaiyun.com/u012968002/article/details/79884920
- 直接加载图和参数,方法同上。
- 只加载参数而不加载图
with tf.Session() as sess:
# 程序前面得有 Variable 供 save or restore 才不报错
# 否则会提示没有可保存的变量
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('ckpt_file_path')
......
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess,'./model/model.ckpt-0')
但是这里有一个问题:保存参数时的图和当前定义的图不一定是一个图,那么加载参数有何作用?难道是可以直接使用参数计算?
先上一段代码:
#保存模型
with tf.Session() as sess:
a = tf.Variable(tf.random_normal([1,2]),name='a')
b = tf.Variable(tf.zeros([2]),name= 'b')
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess,'./model/test_load_var')
#加载模型
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')#加载保存模型文件
print(ckpt.model_checkpoint_path)
a = tf.Variable(tf.random_normal([1, 2]), name='a')
c = tf.Variable(tf.zeros([1,2]),name='c')
sess.run(c.initializer)
saver = tf.train.Saver([a])
if ckpt and ckpt.model_checkpoint_path:
print('------------')
print(ckpt.model_checkpoint_path)
saver.restore(sess,'./model/test_load_var')#恢复参数值
print(sess.run(tf.add(a,c)))
几点说明:
恢复的参数必须要在图中有对应的tensor,不论是从恢复的图中获取也好,还是直接在当前图中定义也好,参数的恢复依赖于tensor。
只要明确需要恢复的参数的name就可以直接使用该参数进行计算。但是对参数量巨大的场景并没有太大的意义,与直接恢复图+参数相比,或许通过这样的方式可以以最小的资源占用来恢复模型,但是会大大加大代码的复杂程度。
可以利用单独恢复的参数和新的数据进行retrain。
tensor的name属性是唯一的,而不是对应tensor的变量名。
- 二进制模型加载办法
# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
# 新建GraphDef文件,用于临时载入模型中的图
graph_def = tf.GraphDef()
# GraphDef加载模型中的图
graph_def.ParseFromString(f.read())
# 在空白图中加载GraphDef中的图
tf.import_graph_def(graph_def,name='')
# 在图中获取张量需要使用graph.get_tensor_by_name加张量名
# 这里的张量可以直接用于session的run方法求值了
# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]
3、代码
'''
保存和加载模型
注意:
保存有metadata和checkpoints。占位符数据不保存,但是保存占位符算子本身。
方法:
tf.train.Saver()
tf.train.Saver().save(sess,model_name)
save方法参数选择:
global_step---每经n次迭代保存,
write_meta_graph---结构只保存一次
max_to_keep---保存版本数量
keep_checkpoint_every_n_hours---每经n小时保存
Saver()类参数:
[]---指定保存变量,未指定全部保存
tf.train.import_meta_gtaph('meta_data_name.meta')
---导入指定网络结构图
tf.train.import_meta_graph('xx.meta').restore(sess,tf.train.latest_checkpoint('.ckpt'))
---恢复模型变量参数
tf.graph.get_tensor_by_name()
---获取保存的占位符和算子
保存为.pb文件
参考:https://blog.youkuaiyun.com/fly_time2012/article/details/82889418
'''
import tensorflow as tf
import os
model_path = './save_and_restore/'
model_name = 'my_test_model'
def save():
w1 = tf.placeholder(dtype=tf.float32,name='w1')
w2 = tf.placeholder(dtype=tf.float32, name='w2')
with tf.variable_scope('test'):
b1 = tf.get_variable('bias',initializer=tf.constant(2.0))
feed_dict = {w1:4,w2:8}
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name='op')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
print(sess.run(w4,feed_dict=feed_dict))
saver.save(sess,os.path.join(model_path,model_name),global_step=1000)
def restore0():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(
os.path.join(model_path,model_name+'-1000.meta')
)
saver.restore(sess,tf.train.latest_checkpoint(model_path))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:13.0,w2:17.0}
op = graph.get_tensor_by_name('op:0')
# variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# b = graph.get_tensor_by_name('test/bias:0')
# b = b+1
# sess.run(tf.global_variables_initializer())
# print(b)
print(sess.run(op,feed_dict))
def restore_reeor():
"""不能以同样的方式恢复占位符,会报错:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'w1_1' with dtype float
因为对于一个占位符而已,它所包含的不仅仅是占位符变量的定义部分,还包含数据,而tensorflow不保存占位符的数据部分。应通过graph.get_tensor_by_name的方式获取,然后在feed数据进去"""
w1 = tf.placeholder(dtype=tf.float32, name='w1')
w2 = tf.placeholder(dtype=tf.float32, name='w2')
with tf.Session() as sess:
saver = tf.train.import_meta_graph(
os.path.join(model_path, model_name + '-1000.meta'))
saver.restore(sess, tf.train.latest_checkpoint(model_path))
graph = tf.get_default_graph()
feed_dict = {w1: 13.0, w2: 17.0}
op_to_restore = graph.get_tensor_by_name('op:0')
print(sess.run(op_to_restore, feed_dict))
#
# save()
# restore0()
def restore2():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(
os.path.join(model_path,model_name+'-1000.meta')
)
saver.restore(sess,tf.train.latest_checkpoint(model_path))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:13.0,w2:17.0}
op = graph.get_tensor_by_name('op')
#增加后续操作到当前图,---问题:假设当前程序为训练程序,运行当前程序是否会导致之前加载的变量改变
add_on_op = tf.multiply(op,2)
print(sess.run(add_on_op))
#恢复原来神经网络的一部分参数或者一部分算子,然后利用这一部分参数或者算子构建新的神经网络模型
def createNN_use_restore():
saver = tf.train.import_meta_graph('vgg.meta')
graph = tf.get_default_graph()
fc7 = graph.get_tensor_by_name('fc7:0')
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()
new_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3],new_outputs],stddev=0.05))
biases = tf.Variable(tf.constant(0.05,shape=[new_outputs]))
output = tf.matmul(fc7,weights)+biases
pred = tf.nn.softmax(output)