import torch
from flops_counter import get_model_complexity_info
# Report the FLOPs and parameter count of ``model`` for a fixed input shape.
# ``input_shape`` must be a tuple (get_model_complexity_info asserts this);
# presumably it is (C, H, W) without a batch dimension -- TODO confirm.
# NOTE(review): ``model`` is not defined in this snippet. It must be bound to
# an ``nn.Module`` before this point, otherwise this line raises NameError.
input_shape = (3, 320, 568)
# With the default ``as_strings=True`` both values come back as strings.
flops, params = get_model_complexity_info(model, input_shape)

# Separator line framing the summary block; built once and reused.
split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
      f'Flops: {flops}\nParams: {params}\n{split_line}')
# ---- flops_counter.py ----
# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch
# MIT License
# Copyright (c) 2018 Vladislav Sovrasov
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import sys
from functools import partial
import numpy as np
import torch
import torch.nn as nn
def get_model_complexity_info(model,
input_shape,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None,
flush=False,
ost=sys.stdout):
"""Get complexity information of a model.
This method can calculate FLOPs and parameter counts of a model with
corresponding input shape. It can also print complexity information for
each layer in a model.
Supported layers are listed as below:
- Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
``nn.ReLU6``.
- Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
``nn.BatchNorm3d``.
- Linear: ``nn.Linear``.
- Deconvolution: ``nn.ConvTranspose2d``.
- Upsample: ``nn.Upsample``.
Args:
model (nn.Module): The model for complexity calculation.
input_shape (tuple): Input shape used for calculation.
print_per_layer_stat (bool): Whether to print complexity information
for each layer in a model. Default: True.
as_strings (bool): Output FLOPs and params counts in a string form.
Default: True.
input_constructor (None | callable): If specified, it takes a callable
method that generates input. otherwise, it will generate a random
tensor with input shape to calculate FLOPs. Default: None.
flush (bool): same as that in :func:`print`. Default: False.
ost (stream): same as ``file`` param in :func:`print`.
Default: sys.stdout.
Returns:
tuple[float | str]: If ``as_strings`` is set to True, it will return
FLOPs and parameter counts in a string format. otherwise, it will
return those in a float number format.
"""
assert type(input_shape) is tuple
assert len(input_shape) >= 1
assert isinstance(mod