pytorch中获取模型input/output shape

本文介绍了在PyTorch中如何通过代码获取模型的输入和输出形状,虽然官方未提供直接方法,但可以通过构造输入并调用forward来获取各个层的shape信息。示例中展示了对CNN和RNN模型的shape计算。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043


以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape


#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json


def get_output_size(summary_dict, output):
  if isinstance(output, tuple):
    for i in xrange(len(output)):
      summary_dict[i] = OrderedDict()
      summary_dict[i] = get_output_size(summary_dict[i],output[i])
  else:
    summary_dict['output_shape'] = list(output.size())
  return summary_dict

def summary(input_size, model):
  def register_hook(module):
    def hook(module, input, output):
      class_name = str(module.__class__).split('.')[-1].split("'")[0]
      module_idx = len(summary)

      m_key = '%s-%i' % (class_name, module_idx+1)
      summary[m_key] = OrderedDict()
      summary[m_key]['input_shape'] = list(input[0].size())
      summary[m_key] = get_output_size(summary[m_key], output)

      params = 0
      if hasattr(module, 'weight'):
        params += torch.prod(torch.LongTensor(list(module.weight.size())))
        if module.weight.requires_grad:
          summary[m_key]['trainable'] = True
        else:
          summary[m_key]['trainable'] = False
      #if hasattr(module, 'bias'):
      #  params +=  torch.prod(torch.LongTensor(list(module.bias.size())))

      summary[m_key]['nb_params'] = params
      
    if not isinstance(module, nn.Sequential) and \
       not isinstance(module, nn.ModuleList) and \
       not (module == model):
      hooks.append(module.register_forward_hook(hook))
  
  # check if there are multiple inputs to the network
  if isinstance(input_size[0], (list, tuple)):
    x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
  else:
    x = Variable(torch.rand(1,*input_size))

  # create properties
  summary = OrderedDict()
  hooks = []
  # register hook
  model.apply(register_hook)
  # make a forward pass
  model(x)
  # remove these hooks
  for h in hooks:
    h.remove()

  return summary

crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)
以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1"

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值