pytorch 笔记:torchsummary、计算模型参数量

1 torchsummary

作用:打印神经网络的结构

  • 总结给定的PyTorch模型。总结的信息包括:

1) 层名称,

2) 输入/输出形状,

3) 核形状,

4) 参数数量,

5) 操作数量(乘加操作)

1.1 参数

model

nn.Module

要总结的PyTorch模型。模型应完全处于train()或eval()模式

如果层不全部处于同一模式,运行总结可能会影响批标准化或dropout统计

input_data

模型的示例输入张量(从模型输入推断数据类型)

- 或 -

输入数据的形状,以List/Tuple/torch.Size形式 (数据类型必须与模型输入匹配,默认为FloatTensors)。

batch_dim
输入数据的批量维度。如果batch_dim为None,则假定输入数据包含批量维度
branching是否使用分支布局打印输出。
col_names

指定输出中要显示的列。

当前支持: ("input_size", "output_size", "num_params", "kernel_size", "mult_adds")

如果未提供输入数据,只使用"num_params"。

默认:("output_size", "num_params")

col_width每列的宽度。 默认:25
depth遍历的嵌套层次数(例如Sequentials)。 默认:3
device

使用此torch设备为模型和输入数据。

如果未指定,使用torch.cuda.is_available()的结果。

dtypes对于多个输入,指定两个输入的大小, 并在这里指定每个参数的类型。
verbose

0(静音):无输出

1(默认):打印模型概要

2(详细):详细显示权重和偏置层

默认:1

1.2 举例

pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-优快云博客 中搭建的CNN为例

import torch
from torchsummary import summary

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
 
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,
#输入shape (1,28,28)
                out_channels=16,
#输出shape(16,28,28),16也是卷积核的数量
                kernel_size=5,
                stride=1,
                padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
 #在2*2空间里面下采样,输出shape(16,14,14)
        )
           
        self.conv2=nn.Sequential(
            nn.Conv2d(
                in_channels=16,
#输入shape (16,14,14)
                out_channels=32,
#输出shape(32,14,14)
                kernel_size=5,
                stride=1,
                padding=2),
#输出shape(32,7,7),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
 
        self.fc=nn.Linear(32*7*7,10)
#输出一个十维的东西,表示我每个数字可能性的权重
        
    def forward(self,x):
            x=self.conv1(x)
            x=self.conv2(x)
            x=x.view(x.shape[0],-1)
            x=self.fc(x)
            return x
    
cnn=CNN()
summary(cnn,(1,28,28))

输出的结果是这样的:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 28, 28]             416
              ReLU-2           [-1, 16, 28, 28]               0
         MaxPool2d-3           [-1, 16, 14, 14]               0
            Conv2d-4           [-1, 32, 14, 14]          12,832
              ReLU-5           [-1, 32, 14, 14]               0
         MaxPool2d-6             [-1, 32, 7, 7]               0
            Linear-7                   [-1, 10]          15,690
================================================================
Total params: 28,938
Trainable params: 28,938
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.32
Params size (MB): 0.11
Estimated Total Size (MB): 0.44
----------------------------------------------------------------

2  手动计算参数量

还是使用上面说的cnn,我们结合torch的numel()方法实现之

num=0
for i in cnn.parameters():
    if i.requires_grad==True:
        num+=i.numel()
num
#28938
sum(i.numel() for i in cnn.parameters() if i.requires_grad==True)
#28938

不难发现和上面的结果是一样的

是什么以及如何使用它查看模型的结构和参数量torchsummary是一个用于查看PyTorch模型结构和参数量的工具。通过使用torchsummary.summary()函数,可以输出模型的输入和输出的形状,以及网络的顺序结构、参数量模型大小等信息。 要使用torchsummary,首先需要在Anaconda prompt中进入自己的PyTorch环境,并安装依赖包torchsummary,可以使用pip install torchsummary命令进行安装。 然后,在代码中导入torchsummary模块,并调用summary()函数,传入模型、输入大小、批次大小和设备等参数。summary()函数将会输出模型的结构和参数量等信息。 这样,就可以使用torchsummary工具来更加清晰地查看PyTorch模型的结构和参数量了。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [pytorch 中的torchsummary](https://blog.youkuaiyun.com/qq_41468616/article/details/121164258)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [pytorch 网络可视化(一):torchsummary](https://blog.youkuaiyun.com/Wenyuanbo/article/details/118514709)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值