RestNet18尝试,报错修改。

新手小白记录,如有错误,请谅解,该文章仅按个人理解编写。

测试网络时,运行报错:

RuntimeError: Given groups=1, weight of size 128 64 3 3, expected input[2, 3, 32, 32] to have 64 channels, but got 3 channels instead。

 解决办法,就是把输入的[2,3,32,32]这个通道数3改成对应的卷积核的通道数即可。

正确代码如下:

import torch
import torch.nn.functional as F
from torch import nn

class Resblock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(Resblock, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # short cut
        # element - Module:[b,ch_in,ch_out,h,w],[b,ch_in,out,h,w]
        # element - wise add:

        print("out", out.shape)       
        print("x", x.shape)
        out = self.extra(x) + out
        out = F.relu(out)
        return out
class ResNet18(nn.Module):
    def __init__(self, ResBlk):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        self.outlayer = nn.Linear(512, 10)

        # followed 4 blocks
        # [b,64,h,w] => [b,128,h,w]
        self.blk1 = ResBlk(64, 128)
        self.blk2 = ResBlk(128, 256)
        self.blk3 = ResBlk(256, 512)
        self.blk4 = ResBlk(512, 1024)

    def forward(self, x):
        x = F.relu(self.conv1)
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = self.outlayer(x)
        return x
    def main():
       blk = Resblock(64, 128)
       tmp = torch.randn(2, 64, 32, 32)
       out = blk(tmp)
       print("block",out.shape)
    if __name__ == "__main__":
       main()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值