在本教程中,我将会解释:
- TensorFlow模型是什么样的?
- 如何保存TensorFlow模型?
- 如何恢复预测/转移学习的TensorFlow模型?
- 如何使用导入的预先训练的模型进行微调和修改?
这个教程假设你已经对神经网络有了一定的了解。如果不了解的话请查阅相关资料。
1. 什么是TensorFlow模型?
训练了一个神经网络之后,我们希望保存它以便将来使用。那么什么是TensorFlow模型?Tensorflow模型主要包含我们所训练的网络参数的网络设计或图形和值。因此,Tensorflow模型有两个主要的文件:
-
Meta graph:
这是一个协议缓冲区,它保存了完整的Tensorflow图形;即所有变量、操作、集合等。该文件以.meta作为扩展名。 -
Checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。然而,Tensorflow从0.11版本中改变了这一点。现在,我们有两个文件,而不是单个.ckpt文件:
mymodel.data-00000-of-00001
mymodel.index
.data文件是包含我们训练变量的文件,我们待会将会使用它。
与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。
因此,为了总结,对于大于0.10的版本,Tensorflow模型如下:
在0.11之前的Tensorflow模型仅包含三个文件:
inception_v1.meta
inception_v1.ckpt
checkpoint
现在我们已经知道了Tensorflow模型的样子,接下来我们来看看TensorFlow是如何保存模型的。
2. 保存TensorFlow模型
比方说,你正在训练一个卷积神经网络来进行图像分类。作为一种标准的练习,你要时刻关注损失和准确率。一旦看到网络已经收敛,我们可以暂停模型的训练。在完成培训之后,我们希望将所有的变量和网络结构保存到一个文件中,以便将来使用。因此,在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。
saver = tf.train.Saver()
请记住,Tensorflow变量仅在会话中存在。因此,您必须在一个会话中保存模型,调用您刚刚创建的save方法。
saver.save(sess, 'my-test-model')
这里,sess是会话对象,而’my-test-model’是保存的模型的名称。让我们来看一个完整的例子:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')
# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint
如果我们想在1000次迭代之后保存模型,我们在save中添加global_step参数:
saver.save(sess, 'my_test_model',global_step=1000)
这将会将’-1000’追加到模型名称,并创建以下文件:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.da