我们通常希望将辛辛苦苦训练好的模型保存起来,以便模型重用。tensorflow的Saver类可以提供此功能。
建Saver类时可传的参数如下所示:
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
‘var_list’指明我们需要保存或重新加载的变量,可以是一个列表或一个字典,比如:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# 将参数作为字典传到:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}),
其默认值是模型中所有可训练的模型。另一个经常用的参数是max_to_keep,它用于设置保存的模型的个数,默认为5.如果你想每训练一个epoch就保存一次模型,可以将max_to_keep设置为None或0。其他参数说明可参考tensorflow官网说明。
下面用一个实例介绍如何用Saver保存和加载模型:
下面的代码的功能是求解方程,给定样本(x,y),求解a和b。
#coding:utf-8
# @Time :2018/10/24 12:39
# @Author :YY
# @Des :
import tensorflow as tf
import numpy as np
x = tf.placeholder( dtype=tf.float32)
y = tf.placeholder(dtype=tf.float32)
a = tf.Variable(1.0,name='a')#参数初始化
b = tf.Variable(0.0,name='b')
pre = a*x+b #要求解的方程
model_loss = tf.nn.l2_loss(pre-y)
optimizer=tf.train.AdamOptimizer(learning_rate=0.1).minimize(model_loss)
X=np.asarray([1,2,3,4,5,6,7,8,9,10])
print(X)
Y=2*X+3
print(Y)
saver=tf.train.Saver()#创建Saver的对象saver
with tf.Session() as sess:
epoches=200
sess.run(tf.global_variables_initializer())
for i in range(epoches):
loss=0
for i in range(len(X)):
sess.run(optimizer,{x:X[i],y:Y[i]})
loss+=sess.run(model_loss,{x:X[i],y:Y[i]})
loss/=len(X)
print('loss',loss)
saver.save(sess,'model/my_model')#保存模型,指明保存的路径
print(sess.run(a))
saver是创建的Saver的一个对象,Saver的save方法的功能是保存模型,其参数如下:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True,
strip_default_attrs=False)
其中sess是用于保存模型的session,而save_path指明模型保存的路径。执行save方法后,在保存模型的路径中会自动生成一个新的文件夹‘model’,其中的文件有:
文件说明:
checkpoint:存放最新保存的checkpoint文件的记录
my_model.meta:这是个协议缓冲区,它保存了完整的Tensorflow图形,即所有变量、操作等
my_model.data:包含训练变量的文件,加载时主要加载这里面的变量
my_model.index:
这里值得注意的是:Tensorflow变量仅存在会话中,所以必须在一个会话中保存模型。
重新加载训练好的模型
如果你想重新加载训练好的模型进行测试或继续训练,有两种方式可以实现重用模型的目的:
第一种:你可以重新构建一个模型,通过以下方式将之前保存的模型参数加载进来,记住,如果新建的网络的参数名称与要加载的模型的参数名称相同,则会自动用要加载的模型的参数覆盖你新建的模型的参数
# _*_coding:utf-8 _*_
#@Time :2019-01-14 15:48
#@Author :YY
#Describe :重载模型后,原来变量的值会被覆盖
import tensorflow as tf
import numpy as np
x = tf.placeholder( dtype=tf.float32)
y = tf.placeholder(dtype=tf.float32)
a = tf.Variable(1.0,name='a')#参数初始化
b = tf.Variable(0.0,name='b')
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, 'model/my_model')#加载模型
a=tf.get_default_graph().get_tensor_by_name("a:0")
print(sess.run(a))
如上述代码所示,我们初始化参数a的值为1.0,而当我们加载之前训练好的模型后,a的值变成2.0,即我们已经训练好的值,这说明同名的变量其值会被覆盖。我们还可以直接加载模型
第二种:通过tf.train.import()方法来直接重建模型。我们已经在my_model.meta文件中保存了构建的网络,可以通过以下方式重新创建这个网络(相当于复制粘贴)
saver = tf.train.import_meta_graph('model/my_model.meta')
但import_meta_graph只是将.meta文件中定义的网络附加到当前图中,此时还未将参数的值加载进来。可以通过以下方式将参数加载进来:
# _*_coding:utf-8 _*_
#@Time :2019-01-14 15:48
#@Author :YY
#Describe :直接加载模型,不用再手动创建
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model/my_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./model'))
print(sess.run('a:0'))
我们可以看到,变量a已经被加载进来。我们还可以使用graph.get_tensor_by_name()方法来操作这个保存的模型:
a=graph.get_tensor_by_name('a:0')
print(sess.run(a))
至此,Saver类的基本用法介绍完毕,以后会继续补充。