Tensorfolw:模型持久化与迁移学习

本文介绍了TensorFlow中模型持久化的两种方式,包括保存为CKPT文件和PB文件。CKPT文件包含模型参数和网络结构,而PB文件能将整个模型保存在单个文件中,便于选择特定节点进行保存和读取。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Tensorflow的模型持久化主要有两种方式,一种是保存为CKPT文件(通过tf.train.Saver()类),一种是保存为pb文件(通过graphutil)。

通过checkpoint

先来看看通过ckpt文件持久化模型的例程:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0),
                 name='v1')
v2 = tf.Variable(tf.constant(2.0),
                 name='v2')
result = tf.add(v1,v2)

init = tf.global_variabels_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess,
           'path/to/model/model.ckpt')

在指定路径下,当前的session被保存在来三个文件下,像这样:


ckpt.data文件是保存模型中所有参数的值

ckpt.index文件保存参数对应的变量名

ckpt.meta保存对应的网络结构,也就是graph

要恢复保存的模型,首先需要导入图的结构(tf.train.import_meta_graph()),然后再恢复各参数权重值(tf.train.Saver.restore()),,这样就得到了保存的整个模型,可以通过name得到模型中任意tensor(saver.restore(sess,tensor_name))。例程如下:

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('ckpt/model.ckpt.meta')
    saver.restore(sess,tf.train.latest_checkpoint('ckpt/'))

    print(sess.run('conv1/conv1_w1:0'))

也可以不专门导入graph,直接实例化一个Saver对象,把ckpt的路径给这个对象,让它自己找图和权重,然后恢复:

with tf.Session() as sess:
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint('path/to/ckpt')
    saver.restore(sess,ckpt)
    print(sess.run('conv1/conv_w1:0'))

通过.pb文件 

前面介绍了通过tf.train.Saver()持久化模型的方法,我们可以看到每次保存模型会生成好几个文件,权重和图结构分布在两个不同的文件中,实际应用中会带来一些不便。graph包括了图中各个节点和常量,如果把所有变量值变成常量,就可以成为图的一部分,把整个模型保存在一个文件中。这样还有一个好处,就是可以选择图中指定的节点保存下来,不用像Saver()一样保存一切的节点和变量。

方法:1.利用convert_variables_to_constants函数将sess中选定的节点固定为常量,并放到序列化图模型中(graph_def)

         2.利用gfile.GFlile函数创建一个.pb文件

         3.将放入了变量的graph_def写入.pb文件中


with tf.Session() as sess:
    graph_def = tf.get_default_graph().as_graph_def()
    graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['opt_name'])
    with gfile.GFile('model.pb','wb') as f:
        f.write(graph_def.SerializeToString())

读取.pb文件中保存的模型:

      1.以二进制形式打开保存的.pb文件

      2.通过ParseFromString()函数读取.pb中的序列化图模型信息导入当前默认graph_def中

      3.通过tf.import_graph_def()导入图

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    graph_def = tf.get_default_graph().as_graph_def()
    with gfile.GFile('pb/model.pb','rb') as f:
        graph_def.ParseFromString(f.read())
    output = tf.import_graph_def(graph_def,return_elements=['conv1/conv_w1:0'])
    print(sess.run(output))
因为.pb的图模型中保存的节点都是常数,我们常用它来保存已经训练好的网络模型,导出后可以直接进行inference。在transfer learning中,输入就是bottleneck节点inference的结果,因此我们常用.pb文件保存好的预训练模型做transfer learning 。



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值