因为要和SOTA比较模型的复杂度,我想知道参数数量。但是模型文件不是tensorflow checkpoint,而是pb文件,我发现当导入graph后,tf.trainable_variables()返回空。
Problem setting : I need to compare with state-of-the-arts the model complexity so the model parameter amount is needed. However the provided model file isn’t the ckpt file, but pb file, and the variables returned by tf.trainable_variables()
is found empty.
这个回答给出了方法。
This answer gives the solution.
举例:
In my case:
# import graph
with open('spmc_120_160_4x3f.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def, input_map=