ResNet Module Code Walkthrough

ResNet architecture diagram

1. Complete code of the ResNet module

```python
import torch.nn as nn
import torch

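# BasicBlock: the two-layer (3x3 -> 3x3) residual block used by the 18- and 34-layer
# ResNets; expansion = 1, so the block's output channel count equals out_channel.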
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

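# Bottleneck: the 1x1 (reduce) -> 3x3 -> 1x1 (expand) residual block used by the
# 50-, 101- and 152-layer ResNets; expansion = 4. The groups / width_per_group
# arguments also let the same block express the ResNeXt variants.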
class Bottleneck(nn.Module):
    
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

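# ResNet: a 7x7 stem convolution and max pooling, followed by four residual stages
# built from blocks_num[i] blocks each; include_top controls whether the classifier
# head (global average pooling + fully connected layer) is attached.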
class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        # A projection shortcut is needed whenever the identity branch no longer
        # matches the block output in spatial size or channel count.
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = [block(self.in_channel,
                        channel,
                        downsample=downsample,
                        stride=stride,
                        groups=self.groups,
                        width_per_group=self.width_per_group)]
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x
```
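With the classes above in place, a specific depth is obtained by choosing the block type and the per-stage block counts from the paper. Below is a minimal usage sketch; the resnet34 / resnet50 helper names are illustrative and not part of the listing above.

```python
import torch


def resnet34(num_classes=1000, include_top=True):
    # ResNet-34: BasicBlock with [3, 4, 6, 3] blocks in the four stages
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # ResNet-50: Bottleneck with [3, 4, 6, 3] blocks in the four stages
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


if __name__ == "__main__":
    model = resnet34(num_classes=5)
    x = torch.randn(2, 3, 224, 224)   # a dummy batch of two 224x224 RGB images
    logits = model(x)
    print(logits.shape)               # expected: torch.Size([2, 5])
```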
### ResNet Paper Explained: A Deep Convolutional Network Architecture

#### 1. Background

In deep learning, adding more layers can expose a model to vanishing or exploding gradients. These problems make training difficult and can even degrade performance. ResNet (the residual network) was proposed as an effective answer to these challenges[^2].

#### 2. Core idea

The core of ResNet is the **residual structure**: skip connections let the network pass its input directly to later layers. This design alleviates the degradation problem of deep networks, where training error stops decreasing and may even increase as the network gets deeper. With residual connections, very deep models can still converge well and generalize well.

#### 3. Mathematical formulation

Let \( x \) be the input to a layer and \( H(x) \) the mapping we want it to learn. Instead of fitting \( H(x) \) directly, ResNet learns the residual function \( F(x) = H(x) - x \), so the output becomes

\[
y = F(x) + x
\]

where \( y \) is the output of the layer. This form makes optimization easier: when the desired mapping is close to the identity \( H(x) = x \), the network only has to drive \( F(x) \) toward zero, which is much easier than fitting the identity through a stack of nonlinear layers[^2].

#### 4. Architectural details

The key components of ResNet are:

- **Residual block.** Each block typically consists of two or three convolutional layers with batch normalization (BN) and ReLU activations. When the input and output dimensions do not match, a linear projection (a 1x1 convolution in the code above) adjusts the identity branch[^2]; a minimal sketch of this case is given at the end of this post.
- **Skip connections.** These connections carry information from earlier layers to later ones unchanged, avoiding the information loss caused by stacking many nonlinear transformations.

A simple PyTorch implementation of the basic residual block:

```python
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels * self.expansion)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
```

#### 5. Applications

ResNet is not limited to image classification; it is widely used in many other computer vision tasks, such as object detection and semantic segmentation. Its success also inspired a number of later architectures, including DenseNet and EfficientNet[^2].
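To make the projection shortcut described in section 4 concrete, here is a minimal sketch; it assumes the BasicBlock defined just above is in scope, and the sizes are only illustrative. When stride=2 halves the feature map and the channel count changes, the identity branch needs a 1x1 convolution plus BN so that the addition is well defined.

```python
import torch
import torch.nn as nn

# Identity branch must be reshaped: 64 channels at 56x56 -> 128 channels at 28x28.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)

block = BasicBlock(64, 128, stride=2, downsample=downsample)

x = torch.randn(1, 64, 56, 56)
y = block(x)
print(y.shape)  # expected: torch.Size([1, 128, 28, 28])
```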