由于经常要使用tensorflow进行网络训练,但是在用的时候每次都要把模型重新跑一遍,这样就比较麻烦;另外由于某些原因程序意外中断,也会导致训练结果拿不到,而保存中间训练过程的模型可以以便下次训练时继续使用。
所以练习了tensorflow的save model和load model。
参考于http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/,这篇教程简单易懂!!
1、保存模型
# 首先定义saver类
#max_to_keep=4表示保存4个模型,默认=5
saver = tf.train.Saver(max_to_keep=4)
# 定义会话
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(200):
# 训练
sess.run(train_step)
if epoch % 10 == 0:
print "------------------------------------------------------"
# 保存模型
saver.save(sess, "model/my-model", global_step=epoch)
print("save the model")
注意点:
-
创建saver时,可以指定需要存储的tensor,如果没有指定,则全部保存。
-
创建saver时,可以指定保存的模型个数,利用max_to_keep=4,则最终会保存最新的4个模型(下图中我保存了160、170、180、190step共4个模型)。默认为5个。 如果没有global_step=epoch这句话,则保存最终一个。
-
saver.save()函数里面可以设定global_step,说明是哪一步保存的模型,如果为=epoch,则保存最后4个epoch的模型,如果为50,则保存第50个epoch的模型。
-
程序结束后,会生成四个文件:存储网络结构.meta、存储训练好的参数.data和.index、记录最新的模型checkpoint。
-
如果saver.save(sess, './params', write_meta_graph=False),则不保存网路结构.meta,默认是True。
如:
2、加载模型
def load_model():
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model/my-model-290.meta')
saver.restore(sess, tf.train.latest_checkpoint("model/"))
注意点:
-
首先导入模型结构import_meta_graph,这里填的名字meta文件的名字。
-
然后restore时,是检查checkpoint,所以只填到checkpoint所在的路径下即可,不需要填checkpoint,不然会报错“ValueError: Can’t load save_path when it is None.”。
-
后面根据具体例子,介绍如何利用加载后的模型得到训练的结果,并进行预测。
-
checkpoint文件的内容:
3、不保存模型结构
tf_x = tf.placeholder(tf.float32, x.shape) # input x
tf_y = tf.placeholder(tf.float32, y.shape) # input y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) # hidden layer
o = tf.layers.dense(l, 1) # output layer
loss = tf.losses.mean_squared_error(tf_y, o) # compute cost
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)
sess = tf.Session()