写在前面的话
之前在pytorch统计参数量和计算量用stat非常的方便,现在转到tensorflow后发现tensorflow的参数量和计算量的统计相对来说就没那么方便了。
在网上也搜了一些相关教程,发现用profiler来进行统计是最方便的。
本文参考了:tensorflow 模型浮点数计算量和参数量估计和TensorFlow程序分析(profile)实战,他们都整理的都非常好,我在这里重新整理一遍方便自己看。
计算量、参数量实验
stats_graph.py
import tensorflow as tf
def stats_graph(graph):
flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
搭建一个Graph实验一下:
with tf.Graph().as_default() as graph:
A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
C = tf.matmul(A, B, name='ouput')
stats_graph(graph)
输出为:
FLOPs: 8288; Trainable params: 544
分析:
先计算一下
C[25,9]=A[25,16]B[16,9]
FLOPs=(16+15)×(25×9)=6975
FLOPs(inTFstyle)=(16+16)×(25×9)=7200
total_parameters=25×16+16×9=544
FLOPs为什么不一样呢?因为利用高斯分布对变量进行初始化会耗费一定的 FLOP
这个问题很关键,涉及到你最后计算量的结果正确性。
通常我们对耗费在初始化上的 FLOPs 并不感兴趣,因为它是发生在训练过程之前且是一次性的,我们感兴趣的是模型部署之后在生产环境下的 FLOPs。我们可以通过 Freeze 计算图的方式得到除去初始化 FLOPs 的、模型部署后推断过程中耗费的 FLOPs。
from tensorflow.python.framework import graph_util
def load_pb(pb):
with tf.gfile.GFile(pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
return graph
with tf.Graph().as_default() as graph:
# ***** (1) Create Graph *****
A = tf.Variable(initial_value=tf.random_normal([25, 16]))
B = tf.Variable(initial_value=tf.random_normal([16, 9]))
C = tf.matmul(A, B, name='output')
print('stats before freezing')
stats_graph(graph)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# ***** (2) freeze graph *****
output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
with tf.gfile.GFile('graph.pb', "wb") as f:
f.write(output_graph.SerializeToString())
# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)
这样得到的结果
stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0
可以发现,freezing之后的计算量才是我们需要的终端推理的时候的计算量。所以我最终选择的是这种统计方式。比较遗憾的是,这时候得到的Trainable params是0,因此想要看参数量还是得先进行上一步。
上述步骤我实验过,是可以的。在实际用的时候要注意input要在graph里面(因为有时候你可能把input的tensor在graph前面就定义好了,我就遇到了这么一个问题)
统计模型的内存、耗时情况
这步我目前还没有做,不确定能否成功
# 统计模型内存和耗时情况
builder = tf.profiler.ProfileOptionBuilder
opts = builder(builder.time_and_memory())
#opts.with_step(1)
opts.with_timeline_output('timeline.json')
opts = opts.build()
#profiler.profile_name_scope(opts) # 只能保存单step的timeline
profiler.profile_graph(opts) # 保存各个step的timeline
有问题多交流,可留言可发邮件,我的邮箱是zhaodongyu艾特pku(这里换成点)edu.cn。