tensorflow 模型浮点数计算量和参数量估计

本文介绍如何使用TensorFlow的profiler工具统计模型的计算量(FLOPs)和参数量,通过freeze计算图去除初始化操作的影响,以获得更准确的终端推理计算量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

写在前面的话

之前在pytorch统计参数量和计算量用stat非常的方便,现在转到tensorflow后发现tensorflow的参数量和计算量的统计相对来说就没那么方便了。
在网上也搜了一些相关教程,发现用profiler来进行统计是最方便的。

本文参考了:tensorflow 模型浮点数计算量和参数量估计TensorFlow程序分析(profile)实战,他们都整理的都非常好,我在这里重新整理一遍方便自己看。

计算量、参数量实验

stats_graph.py

import tensorflow as tf
def stats_graph(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))

搭建一个Graph实验一下:

with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='ouput')
    
    stats_graph(graph)

输出为:

FLOPs: 8288; Trainable params: 544

分析:
先计算一下

C[25,9]=A[25,16]B[16,9] 
FLOPs=(16+15)×(25×9)=6975
FLOPs(inTFstyle)=(16+16)×(25×9)=7200
total_parameters=25×16+16×9=544

FLOPs为什么不一样呢?因为利用高斯分布对变量进行初始化会耗费一定的 FLOP
这个问题很关键,涉及到你最后计算量的结果正确性。

通常我们对耗费在初始化上的 FLOPs 并不感兴趣,因为它是发生在训练过程之前且是一次性的,我们感兴趣的是模型部署之后在生产环境下的 FLOPs。我们可以通过 Freeze 计算图的方式得到除去初始化 FLOPs 的、模型部署后推断过程中耗费的 FLOPs。

from tensorflow.python.framework import graph_util
def load_pb(pb):
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
with tf.Graph().as_default() as graph:
    # ***** (1) Create Graph *****
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')
    
    print('stats before freezing')
    stats_graph(graph)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())
# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)

这样得到的结果

stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0

可以发现,freezing之后的计算量才是我们需要的终端推理的时候的计算量。所以我最终选择的是这种统计方式。比较遗憾的是,这时候得到的Trainable params是0,因此想要看参数量还是得先进行上一步。

上述步骤我实验过,是可以的。在实际用的时候要注意input要在graph里面(因为有时候你可能把input的tensor在graph前面就定义好了,我就遇到了这么一个问题)

统计模型的内存、耗时情况

这步我目前还没有做,不确定能否成功

# 统计模型内存和耗时情况
builder = tf.profiler.ProfileOptionBuilder
opts = builder(builder.time_and_memory())
#opts.with_step(1)
opts.with_timeline_output('timeline.json')
opts = opts.build()

#profiler.profile_name_scope(opts) # 只能保存单step的timeline
profiler.profile_graph(opts) # 保存各个step的timeline

有问题多交流,可留言可发邮件,我的邮箱是zhaodongyu艾特pku(这里换成点)edu.cn。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值