计算机视觉——【tensorflow入门】模型的存储与加载

本文介绍了Tensorflow模型的保存与加载,包括Meta Graph和Checkpoint文件的作用。通过tf.train.Saver()类进行模型保存,加载时需先重建网络结构,再加载数据文件。在加载后可以进行预测、微调或进一步训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

笔者作为深度学习路上的一枚小菜菜,最近查了一下如何保存模型和加载,这里做一下梳理:
模型训练中无可避免的需要做模型的保存和加载,以便于持久化数据和移植代码,这里就简单的介绍一下基于Tensorflow框架的模型保存与再载:

当保存模型的时候,会生成以下文件:
在这里插入图片描述
模型保存的数据分两部分,一部分是图(Graph)或者认为是一系列的运算(Ops);另一部分是数据,即各种变量在运行中的数值;其中.meta文件就是记录着Graph部分,而.index和.data-0000-of-00001后缀的则保存着数据部分;checkpoint为所保存的文件的最新记录,文本文件;

1 什么是Tensorflow model?

当我们训练好一个网络模型时,我们可能需要保存模型和数据甚至进一步部署到应用中。那么什么是Tensorflow 模型呢?Tensorflow模型主要包含两部分:1)网络结构设计或图结构(Graph). 2)训练好的或经过一定训练的网络参数。因此, Tensorflow model 主要有两个文件:

1. 1 Meta Graph

它是一个协议缓冲文件(a protocol buffer),用来保存完整的Tensorflow 图结构,包括所有的变量、操作、集合等,以.meta为文件后缀。

1.2 Checkpoint file

其为二进制文件,涵盖了所有的(在保存时,若没有特殊制定,则保存所有)参数(weights)、偏置(biases)、梯度值(gradients)和其他的变量。此文件以.ckpt为文件后缀。然而,从0.11版本后,Tensorflow就将其分裂成了两个文件:.data-0000-of-0001和.index,如:
在这里插入图片描述
.data文件之后的数字是变化的,0.11版本后,其存储了我们的训练参数。

而与.data .meta .index同时生成的还有一个checkpoint文本文件,其作用是为最近保存的模型文件做记录。

小结:

  1. TTensorflow-0.11版本之前的模型保存后, 文件有三个:
model.meta # Store Graph
model.ckpt # Store variables
checkpoint # Store record latest checkpoint file saved
  1. TTensorflow-0.11版本之前的模型保存后, 文件有四个:
model.meta # Stroe grap
model.data-* # Store variables
model.index
checkpoint # Store record latest checkpoint file saved

2 保存模型

假设说,你在训练一个图像分类的卷积神经网络。作为一个标准操作,需要观察模型的损失函数值和正确率,当模型收敛的时候,你可能想要停止训练或者只运行一定循环数的数据集。当训练结束的时候,通常保存所有的数据和网络图结构到文件就是下一步操作,在Tnesorflow中,是通过tf.train.Saver()类来进行保存模型等相关操作的。

tip:
需要注意的是,Tnesorflow中的变量值只在session运行期间存在,所以我们必须在会话生存期内进行变量的保存:

saver = tf.train.Saver()
with tf.Session() as sess:
	...
   saver.save(sess, 'my-test-model')`

这里,sess是一个session对象,'my-test-model’为想要赋予模型文件的名字,以下为完整代码:

import tensorflow as tf
w_1 = tf.Variables(tf.random_normal(shape[2], name='w_1'))
w_2 = tf.Variables(tf.random_normal(shape=[5]), name='w_2')
saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	saver.save(sess, 'my-test-model')
       
# 以上代码将保存以下文件:
"""
my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta
checkpoint
"""

如果想要在1000个循环后,保存文件,我们可以给出global_step参数:

saver.save(sess, ‘mt-test-model’, global_step=1000)

此时,保存的文件会变成:

saver.save(sess, 'mt-test-model', global_step=1000)
    
# files are saved as follows:
"""
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
"""

进一步说,我们没过1000循环都保存文件,所以.meta文件在第一次就被保存(第一个1000循环到达时),实际上我们不需要每到1000整数次循环就在此重写.meta文件,因为其在训练过程中很少改变,则我们可以加入writer_meta_graph=False:

saver.save(sess, 'my-test-model', global_step=step, write_meta_graph=False)

如果想在磁盘上每隔2小时保存最近的4个模型文件的话,可以使用max_to_keep和keep_checkpoint_every_n_hours:

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

tip:
如果没有特殊指定,tf.train.Saver()将保存所有变量。

若指向保存部分参数,则可以指定变量名或集合,当创建saver时,直接传入参数:

import tensorflow as tf
w_1 = tf.Variable(tf.random_normal(shape=[2]), name='w_1')
w_2 = tf.Variable(tf.random_normal(shape=[5]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值