TensorFlow学习记录:读取已保存的.ckpt模型

本文详细介绍了如何在TensorFlow中保存和恢复模型,特别是如何从.ckpt文件加载模型。通过实例展示了如何使用tf.train.Saver,包括直接加载计算图以及处理变量名不一致的情况。此外,还讲解了如何利用滑动平均值加载模型,以及在加载VGGNet模型后计算测试集上的准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

之前写了一篇使用VGGNet对Cifar10数据集分类的博客,里面最后一句代码对已训练好的模型进行了保存,最近学习了如何读取已训练好的.ckpt模型,所以在这里记录一下。

使用VGGNet对Cifar10数据集分类中,代码的最后一行使用了tf.train.Saver类的save()函数将Tensorflow模型保存到了一个指定路径下,然后这个目录下就多出了四个文件,如下图所示:
在这里插入图片描述
其中,checkpoint文件是一个文本文件,用文本编辑器就能打开。它保存了一个目录下所有的模型文件列表。下图为该文件的内容:
在这里插入图片描述
其中,model_checkpoint_path记录的是当前最新模型的路径,all_model_checkpoint_paths记录的是已保存的5个模型的路径。在Tensorflow中训练并保存模型的时候,会自动保存最新的5个模型,如果后面还有模型要保存的话,Tensorflow会自动删除最早保存的模型以及相应的文件,并且也会自动更新checkpoint文件。

1.模型保存为.ckpt文件代码实现

其余3个文件全部都是二进制文件,我们是没有办法直接打开的。这些文件的作用分别如下表所示:

文件名 说明
model.ckpt-199999.data-00000-of-00001 保存了TensorFlow程序中每一个变量的取值
model.ckpt.index 保存了每一个变量的名称,是一个string-string的table,其中table的key值为tensor名,value值为BundleEntryProto,每个BundleEntryProto表述了tensor的metadata(用于解释或帮助理解信息的数据)
model.ckpt.meta 保存了计算图的结构,或者可以说是神经网络的结构

先定义一个简单的两个变量相加的例子,并且保存为模型

import tensorflow as tf
a = tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b = tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result = a+b

init_op = tf.global_variables_initializer()
saver =tf.train.Saver() #	创建一个Saver类

with tf.Session()as sess:
    sess.run(init_op)
    saver.save(sess,"G:/学习/TensorFlow 深度学习算法原理与编程实践/模型持久化/普通模型/model.ckpt") #Saver类的save()函数用于保存模型,sees是当前会话,后面的是保存路径。如果在训练模型的时候后面还可以加一个参数global_step=当前训练步数

恢复模型的时候,先定义模型里面的所有运算过程,并且变量的name属性须与模型中储存的变量的name属性保持一致(其实也可以不一致,具体做法下面会提到),这样就可以通过变量将已保存的模型加载进来。这部分实现的代码如下:

import tensorflow as tf
# 先定义模型里面的所有运算过程,并且变量的name属性须与模型中储存的变量的name属性保持一致
a = tf.Variable(tf.constant([1.0,2.0],dtype=tf.float32,shape=[2]),name="a")
b = tf.Variable(tf.constant([3.0,4.0],dtype=tf.float32,shape=[2]),name="b")
result = a+b

saver = tf.train.Saver() #	创建一个Saver类。注意,这里当Saver类的对象调用restore()函数时会仅恢复上面所定义的运算过程和变量,假设模型中有保存了一个a*b的运算操作,或者对a和b使用了滑动平均算法,如果这里没有重新定义a*b的运算操作,或者对a和b使用滑动平均算法的话,是不会被恢复进来的,所以当你要打印当前默认计算图的时候是不会有a*b的运算操作或者a和b的影子变量的,一定要注意。
with tf.Session()as sess:
    # 使用restore()函数加载已经保存的模型
    saver.restore(sess,"G:/学习/TensorFlow 深度学习算法原理与编程实践/模型持久化/普通模型/model.ckpt")
    print(sess.run(result))
    # 输出[4. 6.]

有时我们觉得重复定义和模型一样的运算过程非常麻烦,希望可以把计算图也恢复进来。Tensorflow支持直接加载已经持久化的计算图。函数import_meta_graph()实现了这个功能,并且还会把所有计算操作都恢复进来,其输入参数为.meta文件的路径(注意.meta文件并没有存储参数的值)。这个函数返回一个Saver类的对象,再调用Saver对象的restore()函数就可以恢复所有参数了。这部分的代码如下:

import tensorflow as tf
#省略了定义计算图上运算的过程,取而代之的是通过.meta文件直接加载持久化的计算图
meta_graph = tf.train.import_meta_graph("G:/学习/TensorFlow 深度学习算法原理与编程实践/模型持久化/普通模型/model.ckpt.meta")

with tf.Session()as sess:
    # 使用restore()函数加载已经保存的模型
    meta_graph.restore(sess,"G:/学习/TensorFlow 深度学习算法原理与编程实践/模型持久化/普通模型/model.ckpt")
    # 获取默认计算图上指定节点处的张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
    print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))
    print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))
    # 输出
    # [4. 6.]
	# [1. 2.]
	# [3. 4.]

get_default_graph()函数用于获取当前的默认计算图,get_tensor_by_name()函数用于获取计算图中指定节点的张量(add:0表示节点add的第一个输出)。

刚才说了,恢复模型的时候,先定义模型里面的所有运算过程,并且变量的name属性必须与模型中储存的变量的name属性保持一致,如果变量的name属性和模型中变量的name属性不一致的话就不行了码?其实是有办法解决的,可以在初始化Saver类时以字典(Dictionary)的方法,将新定义的变量和模型中的变量联系起来,使用这种方法初始化Saver类时的代码为:

saver = tf.train.Saver({
   
   "a":a,"b":b})

初始化Saver类时使用字典的方法指定了模型中的变量a和新定义的变量a(name属性可以与模型中的变量a的name属性不一样)相关联,模型中变量b和新定义的变量b相关联。

通常我们不会刻意更改变量的name属性,然后使用这种字典的方式进行关联,但当我们在使用变量滑动平均值的时候,如果在加载模型时直接将影子变量的值赋值给原变量,那么就可以省略单独调用函数来获取变量滑动平均值的过程了,这样大大方便了滑动平均模型的使用。下面的代码先定义两个变量和计算这两个变量的滑动平均值,并保存它们。

import tensorflow as tf
a = tf.Variable(0,dtype=tf.float32,name="a")
b = tf.Variable(0,dtype=tf.float32,name="b")

# 先创建滑动平均类,再通过apply()函数将滑动平均算法赋值给所有变量
average_class = tf.train.ExponentialMovingAverage(0.99)
average_op = average_class.apply(tf.all_variables())

for variables
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值