Tensorflow 获取model中的变量列表,用于模型加载等

本文详细介绍在TensorFlow中加载预训练模型时,如何提取特定格式的变量列表,包括使用TensorFlow内置工具、slim库、从保存的model中提取以及直接读取离线文件等多种方法。
部署运行你感兴趣的模型镜像

目录

前言

1. 用tensorflow自带的工具

2. 用tensorflow.contrib.slim。

3. 从保存的model中提取var_list

4. 其他


前言

在加载预训练的网络模型时,有时仅需要利用保存的网络模型中的部分变量,因此我们需要提取这些变量的变量列表,用于变量加载。如以下示例所示,当用tensorflow.train.Saver(var_list)加载模型参数时,需要传入的var_list需要一种特定的数据格式,而直接使用str组成的变量名列表是无效的。

var_list = saver._var_list
print(type(var_list[0]))
#output : <class 'tensorflow.python.ops.variables.RefVariable'>
print(var_list[0])
#output : <tf.Variable 'g_b1:0' shape=(32768,) dtype=float32_ref>

那么,如何获取指定格式的变量列表呢?主要有以下方式:

1. 用tensorflow自带的工具

  • (1) 获取图(graph)中所有变量,包括可训练的与不可训练(training=False)的变量
all_vars = tf.global_variables()
  •  (2)获取图中可训练的变量。
train_vars = tf.trainable_variables()
  • (3)根据变量名在已有变量列表中筛选变量。在(1)或(2)中获取的变量列表中进行筛选
# 从all_vars列表中筛选以”generator“开头的变量,组成一个新的变量列表。
g_vars =[var for var in all_vars for var.name.startswith("generator")]

# 从all_vars列表中筛选name中包含”batch_normalization“的变量
bn_var_list = [v for v in all_vars if "batch_normalization" in v.name ] 
  • (4)提取某个特定变量空间中的变量列表
# 提取variable_scope或name_scope ”generator“下的所有变量的变量列表。
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="generator")

# 提取variable_scope或name_scope ”generator“下的所有可训练变量的变量列表。
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="generator")

2. 用tensorflow.contrib.slim。

tensorflow.contrib.slim是一个强大的工具,也可以用来方便的构建变量列表

  • (1)返回常规变量,应该是返回所有变量(没测试)
import tensorflow.contrib.slim as slim
#下面这行代码返回常规变量
#常规变量是slim里面与model变量对应的一个类型
regular_variables = slim.get_variables()
#你也可以直接
vars = slim.get_variables_to_restore()
     
  • (2)根据条件,筛选变量,构建变量列表
#通过name或前缀筛选
variables = slim.get_variables_by_name("d_")

#通过name后缀筛选
variables = slim.get_variables_by_suffix("_b")

#通过namespace筛选
variables = slim.get_variables(scope="layer1")

#通过include和exclude特定字符串筛选
variables_to_restore = slim.get_variables_to_restore(include=["d_"])
variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])

3. 从保存的model中提取var_list

方法,将离线文件载入当前环境(加载保存的model的图结构),然后使用以上的方法。

注意:在加载model中的图时,要先清空现有的图,否则,import_meta_graph会把原model里面的数据追加到现有的model中,导致一片混乱

# 清空图  
tf.reset_default_graph()
     
with tf.Session(graph=tf.get_default_graph()) as sess:

    # 从model.meta文件中加载保存的图结构
    new_saver = tf.train.import_meta_graph('./model.meta')

    # 从model中加载保存的变量数据
    new_saver.restore(sess, './model')

    # 然后从加载得到的图中获取变量列表,如
    var_list=tf.global_variables()
    print(var_list)

下面是一个从保存的模型中加载图结构,加载变量数值,并打印特定变量的变量值的示例。

import tensorflow as tf

model_path = "/media/***/model.ckpt-18000"

tf.reset_default_graph() # 清空图,防止图上存在干扰节点
with tf.Session(graph=tf.get_default_graph()) as sess:
	saver = tf.train.import_meta_graph(model_path+".meta")# 从model.meta文件中加载保存的图结构
	saver.restore(sess, model_path)# 从model中加载保存的变量数据
	
	var_list = tf.global_variables() # 获取图上的所有变量
	bn_var_list = [v for v in var_list if "batch_normalization" in v.name ] # 筛选与BatchNormalization相关的变量
	for v in bn_var_list: 
		print(v,sess.run(v)[1]) # 打印变量名,与对应的变量值。为了便于观察,打印变量张量的第一个数值

 

4. 其他

  • (1)用pywrap_tensorflow 直接读取离线文件中的变量信息,但是,这种方法获取到的var_list与要求的格式不一样。

 我还不知道怎么转换成saver要求的类型,等一个有缘人相告!

    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
     
    #文件夹地址改成自己的
    model_dir="./model"
     
    ckpt = tf.train.get_checkpoint_state(model_dir)
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
     
    #返回一个dict= {'name':[shape] }
    #例如 'd_w2/Adam':[4, 4, 32, 64]
    var_to_shape_map = reader.get_variable_to_shape_map()
     
    #我们可以用遍历的方式,取出字典里所有的key
    for key in var_to_shape_map:
        print(key)        #key是str类型的
        #再用key去找这个tensor的值
        a=reader.get_tensor(key)
        print(type(a))    #输出: <class 'numpy.ndarray'>
  • (2)把获得的list输出到一个外部txt文件中,以后使用只需要读取txt即可

          PS:目前这种方法只支持str类型保存变量name!不支持saver!

代码示例:

 with open("var_list.txt","w") as f:
     f.write( ','.join(str(var_list)))   #write只能打印str类型,需要强制转换

效果如图:

要使用的时候,只需要按如下代码读取

 with open("var_list.txt","r") as f:
     str1 = f.readlines()[0]
     var_list=str1.split(',')
     print(var_list)

 

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值