这个库可以查看用pytorch搭建的模型,各层输出尺寸。
1.安装
相信大家使用到现在这个地步,导入库已经很熟练了。里面的坑我就不多说。
pip install torchsummary
2.导入
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
上述的your_model是搭建的模型,input_size是你输入的数据。比如输入一张图片大小为28*28的RGB(3个通道)的图片。那么input_size=(3,28,28),模型是根据你输入的大小来输出中间过程的模型尺寸。而且这里不能输入多张比如input_size(8, 3, 28, 28),你想输入八张看效果,不可以的,会报错。
3.使用实例介绍
这里简单导入一个VGG的模型,介绍一下使用情况:
import torch
from torchsummary import summary
from torchvision.models import vgg11
model = vgg11(pretrained=False) #使用现成的VGG11的网络结构
if torch.cuda.is_available():
model.cuda()
summary(model, (3, 224, 224))
输出结果:
----------------------------------------------------------------
Layer (type) Output