Tensorflow计算一个模型的浮点运算数

1、统计模型的浮点运算数和参数量

  • FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
  • FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
  • MACCs:是multiply-accumulate operations),也叫MAdds,意指乘-加操作,理解为计算量,MAdds 大约是 FLOPs 的一半。

    当我们使用tensorflow设计好一个神经网络模型之后,我们如何统计这个模型有多少浮点运算数(FLOPs)和多少参数量呢?我们可以使用tensorflow中的一个模块来进行统计,实例如下(以统计VGG为例):

#coding = utf-8

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg

def stats_graph(graph):
	flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
	params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
	print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.tot
### FLOPs计算方法 FLOPs 是 Floating Point Operations Per Second 的缩写,用于衡量计算机系统的浮点运算能力或者评估深度学习模型中的计算复杂度。以下是关于 FLOPs 计算的具体方法及其工具: #### 手动计算 FLOPs 对于卷积神经网络 (CNN),每层的操作数可以通过分析其结构来手动估算。例如,在卷积层中,假设输入大小为 \(H \times W\),滤波器尺寸为 \(K \times K\),通道数量为 \(C_{in}\),输出通道数量为 \(C_{out}\),步幅为 \(S\),填充为 \(P\),则单次卷积操作的乘法和加法次数可以表示如下[^1]: \[ \text{FLOPs} = H' \cdot W' \cdot C_{out} \cdot K^2 \cdot C_{in} \] 其中, \[ H' = \frac{(H + 2P - K)}{S} + 1,\quad W' = \frac{(W + 2P - K)}{S} + 1 \] 如果仅考虑乘法而不计加法,则上述公式会简化。 #### 自动化工具计算 FLOPs 为了更高效地获取模型FLOPs 数据,通常使用自动化工具完成这一过程。以下是一些常用的框架及相关代码示例: ##### PyTorch 中的 `torchstat` 工具 PyTorch 提供了一些第三方库可以帮助快速统计模型的参数量和 FLOPs 值。例如,`torchstat` 可以通过简单的调用来获得这些数据[^2]: ```python from torchstat import stat import torchvision.models as models model = models.resnet18() input_size = (3, 224, 224) # 输入张量形状 stat(model, input_size) ``` ##### TensorFlow 中的 Flops 统计 在 TensorFlow 中,也可以利用内置的功能模块来进行 FLOPs 的统计[^3]。下面是一个完整的例子展示如何提取总的浮点运算数: ```python import tensorflow.compat.v1 as tf tf.disable_v2_behavior() def calculate_flops(): with tf.Graph().as_default() as graph: sess = tf.Session(graph=graph) # 定义模型并构建图... model_output = ... run_meta = tf.RunMetadata() opts = tf.profiler.ProfileOptionBuilder.float_operation() flops = tf.profiler.profile( graph=graph, run_meta=run_meta, cmd='op', options=opts ) print(f'Total GFLOPS: {flops.total_float_ops / 1e9}') calculate_flops() ``` 以上两种方式分别适用于不同深度学习框架下的应用需求。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值