Deep Learning Camp - Week J5: DenseNet + SE-Net in Practice

📌 This week's tasks (with some of my own improvements):
●1. Insert the SE-Net channel attention mechanism into the DenseNet family of networks and complete classification on the breast cancer dataset
●2. Can this improvement idea be transferred to other settings?
●3. Compare validation-set accuracy

1. Introduction

SE-Net (Squeeze-and-Excitation Network) was the winning model of the ImageNet 2017 classification competition, proposed by the WMW team. It adds little complexity, few parameters, and little extra computation, and its simple design combines easily with existing network structures such as Inception and ResNet to improve their performance.
Traditional architectures such as Inception mainly work on improving features along the spatial dimensions, whereas SE-Net focuses on the relationships between feature channels: it models how important each channel is, not just each spatial position. By learning the importance of every feature channel, SE-Net automatically adjusts the feature weights, boosting channels that are useful for the current task and suppressing unimportant ones. This process is called "feature recalibration".
The SE module is the core building block of SE-Net and implements this recalibration. It consists of two main steps: global average pooling first gathers the global information of each channel, then fully connected layers learn a per-channel importance weight that is used to rescale the channel outputs, as shown in the figure below:
[Figure: the SE module — an input x with C' channels is transformed into a feature map with C channels, then Squeeze, Excitation, and Scale are applied]
For a given input x with C' feature channels, which after a series of transformations becomes a feature map with C channels, SE-Net performs the following three operations (summarized in the formulas after the list):

  1. Squeeze: global average pooling compresses each channel's feature map into a single value, producing a channel descriptor that carries global spatial information. This step can be viewed as "squeezing" each channel down to a summary of its global content.
  2. Excitation: a small fully connected network, usually two layers, where the first layer reduces the dimensionality (to cut model complexity and parameter count) and the second restores it. A Sigmoid then outputs a weight coefficient for each channel, selectively "exciting" the channels.
  3. Scale: finally, the output of the Excitation step (the per-channel weights) is multiplied element-wise with the original input, rescaling the features. Reweighting the channel outputs this way strengthens the model's ability to capture useful features while suppressing unimportant ones.
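
Written out (a standard formulation of the SE module, not copied from the original post; H and W are the spatial dimensions, \sigma the Sigmoid, \delta the ReLU, and W_1, W_2 the two fully connected layers), the three steps are:

z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)    (Squeeze)
s = \sigma\left( W_2 \, \delta( W_1 z ) \right)                       (Excitation)
\tilde{x}_c = s_c \cdot x_c                                           (Scale)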

2. Generality of the SE module

One advantage of the SE module is that it can be dropped directly into existing network structures. Taking Inception and ResNet as examples, we simply append an SE module after the Inception module or the residual module, as illustrated below (a code sketch follows the figure).
[Figure: the SE module appended to an Inception module and to a residual module]
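
As a concrete illustration, here is my own minimal sketch (not the SE-Net paper's reference code) of a ResNet-style basic block with the SE operation appended after the second convolution; the SE part follows the same pool → FC → Sigmoid pattern that is implemented as SELayer in the next section. For simplicity it assumes the input and output channel counts are equal, so the identity shortcut needs no projection.

import torch.nn as nn

class SEBasicBlock(nn.Module):
    # Illustrative sketch: ResNet basic block + SE branch (assumes in/out channels are equal)
    def __init__(self, channels, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        # SE branch: global average pool -> bottleneck FC -> Sigmoid gate
        self.se_pool = nn.AdaptiveAvgPool2d(1)
        self.se_fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # channel recalibration before the residual addition
        b, c, _, _ = out.size()
        w = self.se_fc(self.se_pool(out).view(b, c)).view(b, c, 1, 1)
        out = out * w
        return self.relu(out + identity)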

3. Code implementation of the SE module

import torch
import torch.nn as nn

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: produce per-channel weights
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # Scale: reweight each channel of the input
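
A quick sanity check (my own example, not from the original post): the layer leaves the tensor shape untouched and only rescales each channel.

se = SELayer(channels=64)
feat = torch.randn(8, 64, 32, 32)  # dummy batch of feature maps
out = se(feat)
print(out.shape)  # torch.Size([8, 64, 32, 32]) -- same shape, channels reweighted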

4. Inserting the SE module into DenseNet

Here is the full network definition:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: produce per-channel weights
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # Scale: reweight each channel of the input

# Bottleneck conv block used inside each dense block: BN-ReLU-1x1 conv, then BN-ReLU-3x3 conv
class ConvBlock(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        return x

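# Dense block: each ConvBlock produces growth_rate new channels that are concatenated
# with its input, so the channel count grows by growth_rate per layer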
class DenseBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(ConvBlock(in_channels + i * growth_rate, growth_rate))
    
    def forward(self, x):
        for layer in self.layers:
            new_features = layer(x)
            x = torch.cat([x, new_features], dim=1)
        return x

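# Transition layer: 1x1 conv to reduce channels (here to half), followed by 2x2 average pooling to halve the spatial size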
class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionBlock, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        
    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.avg_pool(x)
        return x

# Build the full DenseNet, with a single SE layer applied after the last dense block
class DenseNet(nn.Module):
    def __init__(self, block_config, num_classes=1000, growth_rate=32):
        super(DenseNet, self).__init__()
        num_init_features = 64
        self.features = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        num_features = num_init_features
        for i, num_blocks in enumerate(block_config):
            block = DenseBlock(num_blocks, num_features, growth_rate)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_features += num_blocks * growth_rate
            if i != len(block_config) - 1:
                trans = TransitionBlock(num_features, num_features // 2)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_features = num_features // 2
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.classifier = nn.Linear(num_features, num_classes)
        self.SE_layer = SELayer(num_features)  # SE Layer should be initialized with the correct number of channels

    def forward(self, x):
        x = self.features(x)                  # stem + dense blocks + transitions + final BN
        x = self.SE_layer(x)                  # channel recalibration of the final feature maps
        x = F.relu(x, inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))  # global average pooling
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    
# Define the different DenseNet configurations
def DenseNet121(num_classes=3):
    return DenseNet([6, 12, 24, 16], num_classes=num_classes)

def DenseNet169(num_classes=3):
    return DenseNet([6, 12, 32, 32], num_classes=num_classes)

def DenseNet201(num_classes=3):
    return DenseNet([6, 12, 48, 32], num_classes=num_classes)

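# Print a layer-by-layer summary; this uses the torchsummary package and moves the model to the GPU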
import torchsummary
model1 = DenseNet121().cuda()
input_size = (3, 224, 224)

torchsummary.summary(model1, input_size)
model2 = DenseNet169().cuda()
model3 = DenseNet201().cuda()

Code output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
        ConvBlock-11           [-1, 32, 56, 56]               0
      BatchNorm2d-12           [-1, 96, 56, 56]             192
             ReLU-13           [-1, 96, 56, 56]               0
           Conv2d-14          [-1, 128, 56, 56]          12,288
      BatchNorm2d-15          [-1, 128, 56, 56]             256
             ReLU-16          [-1, 128, 56, 56]               0
           Conv2d-17           [-1, 32, 56, 56]          36,864
        ConvBlock-18           [-1, 32, 56, 56]               0
      BatchNorm2d-19          [-1, 128, 56, 56]             256
             ReLU-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 128, 56, 56]          16,384
      BatchNorm2d-22          [-1, 128, 56, 56]             256
             ReLU-23          [-1, 128, 56, 56]               0
           Conv2d-24           [-1, 32, 56, 56]          36,864
        ConvBlock-25           [-1, 32, 56, 56]               0
      BatchNorm2d-26          [-1, 160, 56, 56]             320
             ReLU-27          [-1, 160, 56, 56]               0
           Conv2d-28          [-1, 128, 56, 56]          20,480
      BatchNorm2d-29          [-1, 128, 56, 56]             256
             ReLU-30          [-1, 128, 56, 56]               0
           Conv2d-31           [-1, 32, 56, 56]          36,864
        ConvBlock-32           [-1, 32, 56, 56]               0
      BatchNorm2d-33          [-1, 192, 56, 56]             384
             ReLU-34          [-1, 192, 56, 56]               0
           Conv2d-35          [-1, 128, 56, 56]          24,576
      BatchNorm2d-36          [-1, 128, 56, 56]             256
             ReLU-37          [-1, 128, 56, 56]               0
           Conv2d-38           [-1, 32, 56, 56]          36,864
        ConvBlock-39           [-1, 32, 56, 56]               0
      BatchNorm2d-40          [-1, 224, 56, 56]             448
             ReLU-41          [-1, 224, 56, 56]               0
           Conv2d-42          [-1, 128, 56, 56]          28,672
      BatchNorm2d-43          [-1, 128, 56, 56]             256
             ReLU-44          [-1, 128, 56, 56]               0
           Conv2d-45           [-1, 32, 56, 56]          36,864
        ConvBlock-46           [-1, 32, 56, 56]               0
       DenseBlock-47          [-1, 256, 56, 56]               0
      BatchNorm2d-48          [-1, 256, 56, 56]             512
             ReLU-49          [-1, 256, 56, 56]               0
           Conv2d-50          [-1, 128, 56, 56]          32,768
        AvgPool2d-51          [-1, 128, 28, 28]               0
  TransitionBlock-52          [-1, 128, 28, 28]               0
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 128, 28, 28]          16,384
      BatchNorm2d-56          [-1, 128, 28, 28]             256
             ReLU-57          [-1, 128, 28, 28]               0
           Conv2d-58           [-1, 32, 28, 28]          36,864
        ConvBlock-59           [-1, 32, 28, 28]               0
      BatchNorm2d-60          [-1, 160, 28, 28]             320
             ReLU-61          [-1, 160, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]          20,480
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65           [-1, 32, 28, 28]          36,864
        ConvBlock-66           [-1, 32, 28, 28]               0
      BatchNorm2d-67          [-1, 192, 28, 28]             384
             ReLU-68          [-1, 192, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          24,576
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
            Conv2d-72           [-1, 32, 28, 28]          36,864
... (the rest of the summary output is truncated)
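
In the network above, the SE layer is applied only once, to the output of the last dense block. A natural variant to compare on the validation set (my own hedged sketch, not part of the code above) is to recalibrate channels after every dense block, reusing the DenseBlock and SELayer classes already defined:

class SEDenseBlock(nn.Module):
    # DenseBlock followed by channel recalibration of the concatenated feature maps
    def __init__(self, num_layers, in_channels, growth_rate, reduction=16):
        super(SEDenseBlock, self).__init__()
        self.block = DenseBlock(num_layers, in_channels, growth_rate)
        self.se = SELayer(in_channels + num_layers * growth_rate, reduction)

    def forward(self, x):
        return self.se(self.block(x))

Swapping DenseBlock for SEDenseBlock in DenseNet.__init__ keeps the channel bookkeeping unchanged, because SE does not alter the number of channels; whether it actually improves accuracy on the breast cancer dataset is exactly the kind of validation-set comparison that task 3 calls for.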