# 1、加载训练好的 TF 模型，查看变量 (Load a trained TF model and inspect its variables)
# Restore a trained checkpoint through the graph/session API and dump every
# trainable variable whose name contains 'embedding_weights'.
import numpy as np
import tensorflow as tf

# Directory that holds the checkpoint files and the exported .meta graph.
export_dir = '/export/wangjie648/cornerstone/data/model/'

# Import the graph stored in the .meta file; the returned object is used
# below to restore the variable values into the session.
saver = tf.train.import_meta_graph(export_dir + 'model.ckpt-10555.meta')

# Names of all trainable variables in the imported graph. The same list is
# passed to sess.run() as the fetch list further down.
variable_names = [v.name for v in tf.trainable_variables()]
config = tf.ConfigProto(allow_soft_placement=True)

with tf.Session(config=config) as sess:
    ckpt = tf.train.get_checkpoint_state(export_dir)
    # get_checkpoint_state() yields None when the directory has no usable
    # checkpoint state; fail with a clear message instead of an
    # AttributeError on `ckpt.model_checkpoint_path`.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint state found in ' + export_dir)

    # NOTE(review): the checkpoint restored here is whatever the checkpoint
    # state file points at, which is not necessarily model.ckpt-10555 --
    # confirm they match.
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)

    # Evaluate every trainable variable in one call; `values` lines up
    # index-for-index with `variable_names`.
    values = sess.run(variable_names)
    for k, v in zip(variable_names, values):
        if 'embedding_weights' in k:
            print("Variable: ", k)
            print("Shape: ", v.shape)
            print("type:", type(v))
            # NOTE: np.save() writes NumPy's binary format and appends
            # ".npy" when the name does not end with it, so the file on disk
            # is "embedding_120w.txt.npy", not a text file. If more than one
            # variable matches, each save overwrites the previous one.
            np.save('embedding_120w.txt', v)
# List the contents of a checkpoint with TensorFlow's inspection helper.
from tensorflow.python.tools.inspect_checkpoint import (
    print_tensors_in_checkpoint_file,
)

# NOTE(review): with an empty tensor_name and all_tensors=False the helper is
# expected to print tensor names and shapes only, not values -- confirm
# against the installed TensorFlow version.
print_tensors_in_checkpoint_file(
    file_name='/export/wangjie648/notbook/z_model/teacher.2021.Q1.data.student.model/model.ckpt-12379767',
    tensor_name='',
    all_tensors=False,
)
# Look up a single embedding tensor in a checkpoint and print its shape.
import tensorflow as tf

checkpoint_file = '/export/wangjie648/notbook/z_model/teacher.2021.Q1.data.student.model/model.ckpt-12379767'
reader = tf.train.NewCheckpointReader(checkpoint_file)

# Maps each variable name stored in the checkpoint to its shape.
var_to_shape_map = reader.get_variable_to_shape_map()

# Only one known key is of interest, so test for it directly instead of
# scanning every entry of the map. Nothing is printed when it is absent.
key = 'input_layer/word/embedding_weights'
if key in var_to_shape_map:
    print("tensor_name: ", key)
    # The shape is read from the loaded tensor itself so the printed form
    # stays the same as before.
    print(reader.get_tensor(key).shape)
# Open the checkpoint at this path and collect per-variable metadata.
reader = tf.train.load_checkpoint(
    '/media/cfs/wangjie648/local/jxpp-distillation_local_two_label_v4'
)
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

# The next two expressions are evaluated but their results are neither stored
# nor printed; they only show output in an interactive/notebook session.
# First: all variable keys in sorted order.
sorted(shape_from_key)
# Second: keys whose shape has a dimension equal to 21128
# (presumably the vocabulary size -- TODO confirm).
[name for name, shape in shape_from_key.items() if 21128 in shape]

# Print shape and dtype for one specific variable.
key = 'embeddings/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
# 2、导出 tf.keras 模型参数（summary 方法） (Export tf.keras model parameters via the summary method)
# Write the model summary to <model_dir>/model_summary.txt, one summary line
# per file line. `model` and `model_dir` must already be defined when this
# snippet runs.
# The encoding is pinned to UTF-8 so that non-ASCII text in the summary
# (e.g. layer names or table borders) does not depend on the platform's
# default encoding.
with open(model_dir + '/model_summary.txt', 'w', encoding='utf-8') as fh:
    # summary() hands each piece of text to print_fn; append the newline here.
    model.summary(print_fn=lambda line: fh.write(line + '\n'))