Tensorflow 模型加载与存储
笔者作为深度学习路上的一枚小菜菜,最近查了一下如何保存模型和加载,这里做一下梳理:
模型训练中无可避免的需要做模型的保存和加载,以便于持久化数据和移植代码,这里就简单的介绍一下基于Tensorflow框架的模型保存与再载:
当保存模型的时候,会生成以下文件:
模型保存的数据分两部分,一部分是图(Graph)或者认为是一系列的运算(Ops);另一部分是数据,即各种变量在运行中的数值;其中.meta文件就是记录着Graph部分,而.index和.data-0000-of-00001后缀的则保存着数据部分;checkpoint为所保存的文件的最新记录,文本文件;
1 什么是Tensorflow model?
当我们训练好一个网络模型时,我们可能需要保存模型和数据甚至进一步部署到应用中。那么什么是Tensorflow 模型呢?Tensorflow模型主要包含两部分:1)网络结构设计或图结构(Graph). 2)训练好的或经过一定训练的网络参数。因此, Tensorflow model 主要有两个文件:
1. 1 Meta Graph
它是一个协议缓冲文件(a protocol buffer),用来保存完整的Tensorflow 图结构,包括所有的变量、操作、集合等,以.meta为文件后缀。
1.2 Checkpoint file
其为二进制文件,涵盖了所有的(在保存时,若没有特殊制定,则保存所有)参数(weights)、偏置(biases)、梯度值(gradients)和其他的变量。此文件以.ckpt为文件后缀。然而,从0.11版本后,Tensorflow就将其分裂成了两个文件:.data-0000-of-0001和.index,如:
.data文件之后的数字是变化的,0.11版本后,其存储了我们的训练参数。
而与.data .meta .index同时生成的还有一个checkpoint文本文件,其作用是为最近保存的模型文件做记录。
小结:
- TTensorflow-0.11版本之前的模型保存后, 文件有三个:
model.meta # Store Graph
model.ckpt # Store variables
checkpoint # Store record latest checkpoint file saved
- TTensorflow-0.11版本之前的模型保存后, 文件有四个:
model.meta # Stroe grap
model.data-* # Store variables
model.index
checkpoint # Store record latest checkpoint file saved
2 保存模型
假设说,你在训练一个图像分类的卷积神经网络。作为一个标准操作,需要观察模型的损失函数值和正确率,当模型收敛的时候,你可能想要停止训练或者只运行一定循环数的数据集。当训练结束的时候,通常保存所有的数据和网络图结构到文件就是下一步操作,在Tnesorflow中,是通过tf.train.Saver()类来进行保存模型等相关操作的。
tip:
需要注意的是,Tnesorflow中的变量值只在session运行期间存在,所以我们必须在会话生存期内进行变量的保存:saver = tf.train.Saver() with tf.Session() as sess: ... saver.save(sess, 'my-test-model')`
这里,sess是一个session对象,'my-test-model’为想要赋予模型文件的名字,以下为完整代码:
import tensorflow as tf
w_1 = tf.Variables(tf.random_normal(shape[2], name='w_1'))
w_2 = tf.Variables(tf.random_normal(shape=[5]), name='w_2')
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-test-model')
# 以上代码将保存以下文件:
"""
my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta
checkpoint
"""
如果想要在1000个循环后,保存文件,我们可以给出global_step参数:
saver.save(sess, ‘mt-test-model’, global_step=1000)
此时,保存的文件会变成:
saver.save(sess, 'mt-test-model', global_step=1000)
# files are saved as follows:
"""
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
"""
进一步说,我们没过1000循环都保存文件,所以.meta文件在第一次就被保存(第一个1000循环到达时),实际上我们不需要每到1000整数次循环就在此重写.meta文件,因为其在训练过程中很少改变,则我们可以加入writer_meta_graph=False:
saver.save(sess, 'my-test-model', global_step=step, write_meta_graph=False)
如果想在磁盘上每隔2小时保存最近的4个模型文件的话,可以使用max_to_keep和keep_checkpoint_every_n_hours:
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
tip:
如果没有特殊指定,tf.train.Saver()将保存所有变量。
若指向保存部分参数,则可以指定变量名或集合,当创建saver时,直接传入参数:
import tensorflow as tf
w_1 = tf.Variable(tf.random_normal(shape=[2]), name='w_1')
w_2 = tf.Variable(tf.random_normal(shape=[5]