Tensorflow的模型持久化主要有两种方式,一种是保存为CKPT文件(通过tf.train.Saver()类),一种是保存为pb文件(通过graphutil)。
通过checkpoint
先来看看通过ckpt文件持久化模型的例程:
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0),
name='v1')
v2 = tf.Variable(tf.constant(2.0),
name='v2')
result = tf.add(v1,v2)
init = tf.global_variabels_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess,
'path/to/model/model.ckpt')
在指定路径下,当前的session被保存在来三个文件下,像这样:
ckpt.data文件是保存模型中所有参数的值
ckpt.index文件保存参数对应的变量名
ckpt.meta保存对应的网络结构,也就是graph
要恢复保存的模型,首先需要导入图的结构(tf.train.import_meta_graph()),然后再恢复各参数权重值(tf.train.Saver.restore()),,这样就得到了保存的整个模型,可以通过name得到模型中任意tensor(saver.restore(sess,tensor_name))。例程如下:
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('ckpt/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint('ckpt/'))
print(sess.run('conv1/conv1_w1:0'))
也可以不专门导入graph,直接实例化一个Saver对象,把ckpt的路径给这个对象,让它自己找图和权重,然后恢复:
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint('path/to/ckpt')
saver.restore(sess,ckpt)
print(sess.run('conv1/conv_w1:0'))
通过.pb文件
前面介绍了通过tf.train.Saver()持久化模型的方法,我们可以看到每次保存模型会生成好几个文件,权重和图结构分布在两个不同的文件中,实际应用中会带来一些不便。graph包括了图中各个节点和常量,如果把所有变量值变成常量,就可以成为图的一部分,把整个模型保存在一个文件中。这样还有一个好处,就是可以选择图中指定的节点保存下来,不用像Saver()一样保存一切的节点和变量。
方法:1.利用convert_variables_to_constants函数将sess中选定的节点固定为常量,并放到序列化图模型中(graph_def)
2.利用gfile.GFlile函数创建一个.pb文件
3.将放入了变量的graph_def写入.pb文件中
with tf.Session() as sess:
graph_def = tf.get_default_graph().as_graph_def()
graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['opt_name'])
with gfile.GFile('model.pb','wb') as f:
f.write(graph_def.SerializeToString())
读取.pb文件中保存的模型:
1.以二进制形式打开保存的.pb文件
2.通过ParseFromString()函数读取.pb中的序列化图模型信息导入当前默认graph_def中
3.通过tf.import_graph_def()导入图
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
graph_def = tf.get_default_graph().as_graph_def()
with gfile.GFile('pb/model.pb','rb') as f:
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def,return_elements=['conv1/conv_w1:0'])
print(sess.run(output))
因为.pb的图模型中保存的节点都是常数,我们常用它来保存已经训练好的网络模型,导出后可以直接进行inference。在transfer learning中,输入就是bottleneck节点inference的结果,因此我们常用.pb文件保存好的预训练模型做transfer learning 。