TF Saver 保存/加载训练好模型(网络+参数)的那些事儿

本文介绍了如何在TensorFlow中保存和加载模型。包括使用SavedModel进行语言中立且可恢复的序列化,以及利用元图(meta_graph)和检查点(checkpoint)文件保存和加载模型的具体步骤。

对于神经网络的读取和存储,TF推荐我们使用:When you want to save and load variables, the graph, and the graph's metadata--basically, when you want to save or restore your model--we recommend using SavedModel 推荐使用savedmodel. SavedModel is a language-neutral, recoverable, hermetic serialization format. SavedModel enables higher-level systems and tools to produce, consume, and transform TensorFlow models. TensorFlow provides several mechanisms for interacting with SavedModel, including tf.saved_model APIs, Estimator APIs and a CLI. 。

一般模式:

  • 训练时候保存
# Create a saver.
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Saves checkpoint, which by default also exports a meta_graph
        # named 'my-model-global_step.meta'.
        saver.save(sess, 'my-model', global_step=step)

另一个代码加载模型训练好的的网络和参数来测试,或进一步训练

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-model.meta')
  new_saver.restore(sess,tf.train.latest_checkpoint( './')
  ...
  

模型训练完毕之后,你可能需要在产品上使用它。那么tensorflow model是什么?tensorflow模型主要包含网络的结构的定义或者叫graph和训练好的网络结构里的参数。

因此tensorflow model包含2个文件:

a)Meta graph:

使用protocol buffer来保存整个tensorflow graph.例如所有的variables, operations, collections等等。这个文件使用.meta后缀

b) Checkpoint file:

二进制文件包含所有的weights,biases,gradients和其他variables的值。这个文件使用.ckpt后缀,后来变成有2个文件:

mymodel.data-00000-of-00001
mymodel.index

第一个文件.data文件就是保存训练的variables我们将要使用它。

和这些文件一起,tensorflow还有一个文件叫checkpoint用来简单保存最近一次保存checkpoint文件的记录

 

导入训练好的模型,有两件事必须做:

1)创造网络:最直接的方法是再逐层写一遍与原模型一样的网络结构,但是meta文件已经把原始网络保存起来了,因此可以直接导入为你创建网络。saver=tf.train.import_meta_graph('my_model-1000.meta')

2)加载模型:通过saver恢复网络的参数,saver.restore(sess,tf.train.latest_checkpoint('./'))

 

注意:这个checkpoint文件,在导入训练好模型的时候会有如下问题: 当代码知道你指定了log directory 而且在你之前训练胡时候已经在该log目录下生成了checkpoint 模型,它就会直接从你的log目录下去restore checkpoint 文件,而不是the original TF-Slim checkpoint 模型。这一点很多人会忽视。

error like:

2018-01-07 15:17:30.960482: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conduct_encoder/layer9-conv9/E_weights not found in checkpoint

方案:当你改变代码重新训练模型之后,应该删除之前的 log文件。或者你不要有保存log文件的代码。

问题2:当我们保存多个模型在一个目录下面时,checkpoint文件只有一个,默认后来的覆盖最初的checkpoint内容,导致你加载不到自己想要加载的模型,而是加载了最后训练成的模型。

方案:因此如果保存一个训练模型时候,尽量给它一个单另的文件夹,比如

saver(sess, 'First_net/my_model')

 

 

转载于:https://my.oschina.net/liusicong/blog/1604281

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值