神奇宝贝五分类:网络自定义

本文介绍如何利用ResNet18深度学习模型进行神奇宝贝图片的五分类任务,详细探讨网络结构的定制过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

用resnet18的结构

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    '''
    resnet block
    '''
    def __init__(self,ch_in,ch_out,stride=1):
        '''

        :param ch_in:
        :param ch_out:
        '''
        super(ResBlk,self).__init__()

        self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1=nn.BatchNorm2d(ch_out)  # ResNet,一般都会加BatchNorm
        self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn2=nn.BatchNorm2d(ch_out)

        self.extra=nn.Sequential()  # 先设为空的extra,如果有的话,下面的就把这个覆盖掉
        if ch_out != ch_in:
            #把[b,ch_in,h.w]=>[b,ch_out,h,w]
            self.extra =nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride), # 保持size不变
                nn.BatchNorm2d(ch_out)
            )


    def forward(self,x):
        '''

        :param x:[b,ch,h,w]
        :return:
        '''
        # 因为后边shortcut还需要x,所以这里用out命名,不用x
        out=F.relu(self.bn1(self.conv1(x)))
        out=self.bn2(self.conv2(out))
        # shortcut
        # extra model:[b,ch_in,h,w]=>[b.ch_out,h,w],所以相加应该保持ch_in和ch_out维度相等
        # element wise add
        out=self.extra(x)+out
        out=F.relu(out)

        return out


class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1=nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(16)
        )
        # follow 4 blocks
        # [b,16,h,w]=>[b,32,h,w]
        self.blk1=ResBlk(16,32,stride=3)
        # [b.32,h,w]=>[b,64,h,w]
        self.blk2=ResBlk(32,64,stride=3)
        # [b,64,h,w]=>[b,128,h,w]
        self.blk3=ResBlk(64,128,stride=2)
        # [b,128,h,w]=>[b,256,h,w]
        self.blk4=ResBlk(128,256,stride=2)

        self.outlayer=nn.Linear(256*1*1,5)  # 变成十类

    def forward(self, x):
        '''

        :param x:
        :return:
        '''


        x=F.relu(self.conv1(x))
        # [b,64,h,w]=>[b.1024,h,w]
        x=self.blk1(x)
        print("block3", x.shape)
        x=self.blk2(x)
        print("block4", x.shape)
        x=self.blk3(x)
        print("block5", x.shape)
        x=self.blk4(x)
        print("block6", x.shape)
        print(x.shape)

        x=F.adaptive_avg_pool2d(x,[1,1])
        print('after pool:',x.shape)
        x=x.view(x.size(0),-1)
        x=self.outlayer(x)


        return x

#用来测试
def main():

    blk=ResBlk(64,128,stride=2)  # stride可以有效实现长和宽的衰减
    tmp=torch.randn(2,64,224,224)
    out=blk(tmp)
    print('block:',out.shape)

    x=torch.randn(2,3,224,224)
    model=ResNet18()
    out=model(x)
    print('outshape',out.shape)

# 来计算一下未知的维度信息
if __name__ == '__main__':
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值