使用MMEngine计算模型FLOPs和参数量详解

使用MMEngine计算模型FLOPs和参数量详解

【免费下载链接】mmengine OpenMMLab Foundational Library for Training Deep Learning Models 【免费下载链接】mmengine 项目地址: https://gitcode.com/gh_mirrors/mm/mmengine

什么是FLOPs和参数量

在深度学习模型开发过程中,了解模型的**计算复杂度(FLOPs)参数量(Parameters)**非常重要:

  • FLOPs(Floating Point Operations):浮点运算次数,衡量模型的计算复杂度
  • 参数量:模型中所有可训练参数的总数,影响模型大小和内存占用

这两个指标直接影响模型的:

  • 推理速度
  • 内存占用
  • 部署难度
  • 训练成本

MMEngine中的模型分析工具

MMEngine提供了get_model_complexity_info函数,可以方便地计算模型的FLOPs和参数量。下面我们通过一个完整的示例来演示如何使用。

完整使用示例

1. 定义模型

首先我们需要定义一个继承自BaseModel的模型类。这里我们以ResNet50为例:

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels=None, mode='tensor'):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
        elif mode == 'tensor':
            return x

2. 计算FLOPs和参数量

使用get_model_complexity_info函数计算模型复杂度:

from mmengine.analysis import get_model_complexity_info

# 定义输入张量的形状 (通道数, 高度, 宽度)
input_shape = (3, 224, 224)
model = MMResNet50()

# 获取分析结果
analysis_results = get_model_complexity_info(model, input_shape)

3. 查看分析结果

分析结果提供了两种展示方式:

表格形式展示
print(analysis_results['out_table'])

输出结果会以表格形式展示每个模块的详细参数和计算量:

+------------------------+----------------------+------------+--------------+
| module                 | #parameters or shape | #flops     | #activations |
+------------------------+----------------------+------------+--------------+
| resnet                 | 25.557M              | 4.145G     | 11.115M      |
|  conv1                 |  9.408K              |  0.118G    |  0.803M      |
|   conv1.weight         |   (64, 3, 7, 7)      |            |              |
|  bn1                   |  0.128K              |  4.014M    |  0           |
|   bn1.weight           |   (64,)              |            |              |
|   bn1.bias             |   (64,)              |            |              |
|  layer1                |  0.216M              |  0.69G     |  4.415M      |
... (后续详细表格数据)
模型结构形式展示
print(analysis_results['out_arch'])

输出结果会按照模型结构层级展示:

MMResNet50(
#params: 25.56M, #flops: 4.14G, #acts: 11.11M
(data_preprocessor): BaseDataPreprocessor(#params: 0, #flops: N/A, #acts: N/A)
(resnet): ResNet(
    #params: 25.56M, #flops: 4.14G, #acts: 11.11M
    (conv1): Conv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    #params: 9.41K, #flops: 0.12G, #acts: 0.8M
    )
    (bn1): BatchNorm2d(
    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    #params: 0.13K, #flops: 4.01M, #acts: 0
    )
    ... (后续详细层级结构)
)

结果解读

从分析结果可以看出:

  1. 总参数量:25.56百万(25.56M)
  2. 总计算量:4.14十亿次浮点运算(4.14G FLOPs)
  3. 激活值数量:11.11百万(11.11M)

进一步分析各层贡献:

  • 第一层卷积(conv1)占用了约9.4K参数和0.12G FLOPs
  • 第一个残差块(layer1)占用了约0.22M参数和0.69G FLOPs
  • 全连接层(fc)占用了约2M参数

实际应用建议

  1. 模型优化:通过分析结果可以找出计算量大的层,针对性优化
  2. 部署准备:了解模型计算量有助于选择合适的硬件
  3. 模型对比:可以客观比较不同模型的计算效率
  4. 研究分析:帮助理解模型各部分的计算贡献

注意事项

  1. 输入形状会影响FLOPs计算结果,需根据实际应用场景设置
  2. 某些特殊操作可能无法准确统计计算量
  3. 实际运行时的性能还受内存访问、并行度等因素影响

通过MMEngine提供的模型分析工具,开发者可以快速获取模型的复杂度指标,为模型优化和部署提供重要参考。

【免费下载链接】mmengine OpenMMLab Foundational Library for Training Deep Learning Models 【免费下载链接】mmengine 项目地址: https://gitcode.com/gh_mirrors/mm/mmengine

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值