加载tensorflow ckpt 模型

1、加载训练好的tf模型,查看变量

# 方式1,导出模型需要设置export_embedding=true
export_dir = '/export/wangjie648/cornerstone/data/model/'
# tf.contrib.resampler
saver = tf.train.import_meta_graph(export_dir+'model.ckpt-10555.meta')
variable_names = [v.name for v in tf.trainable_variables()]
config = tf.ConfigProto(allow_soft_placement=True)
import numpy as np
with tf.Session(config=config) as sess:
    ckpt = tf.train.get_checkpoint_state(export_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    values = sess.run(variable_names)
    for k,v in zip(variable_names, values):
        if 'embedding_weights' in k:
            print("Variable: ", k)
            print("Shape: ", v.shape)
            print("type:", type(v))
    np.save('embedding_120w.txt', v)
 
 
# 方式2
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name='/export/wangjie648/notbook/z_model/teacher.2021.Q1.data.student.model/model.ckpt-12379767', tensor_name='', all_tensors=False)
# print_tensors_in_checkpoint_file(file_name='/export/wangjie648/notbook/z_model/model/model.ckpt-14562',tensor_name='input_layer/word/embedding_weights', all_tensors=False)
# 方式3
import tensorflow as tf
 
# checkpoint_file = '/export/wangjie648/notbook/z_model/model/model.ckpt-14562'
checkpoint_file = '/export/wangjie648/notbook/z_model/teacher.2021.Q1.data.student.model/model.ckpt-12379767'
 
reader = tf.train.NewCheckpointReader(checkpoint_file)
#print(reader.debug_string().decode("utf-8"))
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    if key=='input_layer/word/embedding_weights':
        print("tensor_name: ", key)   # 打印变量名
        print(reader.get_tensor(key).shape) # 打印变量值

# 方式4
reader = tf.train.load_checkpoint('/media/cfs/wangjie648/local/jxpp-distillation_local_two_label_v4')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
[i  for i in shape_from_key.keys() if 21128 in shape_from_key[i]]
key = 'embeddings/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)

2、导出tf.keras.model参数(summary方法)

with open(model_dir + '/model_summary.txt','w') as fh:
    model.summary(print_fn=lambda x: fh.write(x + '\n'))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值