pytorch使用torchvision0.2版本-SqueezeNet报错问题解决:Calculated output size: (5x0x0). Output size is too small

SqueezeNet在PyTorch不同版本的适配问题
本文解决在PyTorch 1.0和torchvision 0.2版本下,使用SqueezeNet进行图像分类时遇到的RuntimeError问题。问题源于平均池化操作的kernel_size与输入尺寸不匹配,导致输出尺寸过小。通过对比torchvision 0.3版本的解决方案,即使用自适应平均池化,成功解决了该问题。
部署运行你感兴趣的模型镜像


  • pytorch版本:1.0
  • torchvision版本:0.2.0
  • Python版本:3.7

问题描述如下:

RuntimeError: Given input size: (5x12x12). Calculated output size:
(5x0x0). Output size is too small at
c:\a\w\1\s\tmp_conda_3.7_105830\conda\conda-bld\pytorch_1544094179692\work\aten\src\thcunn \generic/SpatialAveragePooling.cu:47

下面为网络结构代码:

from torchvision.models import squeezenet1_0, squeezenet1_1
import torch.nn as nn
import torch

class SqueezeNet0(nn.Module):
    def __init__(self, num_classes=5):
        super(SqueezeNet0, self).__init__()
        self.num_classes = num_classes
        self.squeezenet = squeezenet1_0(pretrained=True)
        self.squeezenet.features[0] = nn.Conv2d(64, 96, kernel_size=7, stride=2)
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.squeezenet.classifier[1] = final_conv
    def forward(self, x):
        return self.squeezenet(x)

torchvision 0.2.0定义的squeezenet.py

图片的输入尺寸是200x200;

在我自己的本地python环境下报错,但是在服务器上却没有问题。
查看了源码发现,torchvision 0.2.0版本的中的squeezenet.py定义的最后平均池化操作比较死,平均池化的kernel_size定义的是13,而我查看上一个网络层的输出却是5x12x12,因此输出尺寸太小,不匹配平均池化操作的kernel_size,所以才有输出尺寸为0的报错。

      final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AvgPool2d(13, stride=1)  # 问题所在
        )

torchvision 0.3.0定义的squeezenet.py

服务器的torchvision版本是0.3.0,其可以正常运行一样的代码,查看源码,发现其是这样定义的squeezenet,使用的是nn中定义的自适应平均池化,直接将其池化到1x1的size大小,所以正常运行输入尺寸是200x200大小的图片。

      final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1,1)   # 问题所在
        )

希望能帮到大家。
2020.1.4


您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

Traceback (most recent call last): File "D:\pythonproject\pythonProject\AlexNet\model_train.py", line 196, in <module> train_process = train_model_process(AlexNet, train_data_loader, val_data_loader, num_epochs=1) File "D:\pythonproject\pythonProject\AlexNet\model_train.py", line 102, in train_model_process output = model(b_x) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "D:\pythonproject\pythonProject\AlexNet\model.py", line 28, in forward x = self.pool4(x) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\modules\pooling.py", line 224, in forward return F.max_pool2d( File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\_jit_internal.py", line 627, in fn return if_false(*args, **kwargs) File "C:\Users\shen\.conda\envs\pytorch_env\lib\site-packages\torch\nn\functional.py", line 827, in _max_pool2d return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) RuntimeError: Given input size: (256x2x2). Calculated output size: (256x0x0). Output size is too small
最新发布
09-26
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值