GroupMamba解析及测试

GroupMamba解析

Shaker A, Wasim S T, Khan S, et al. GroupMamba: Parameter-Efficient and Accurate Group Visual State Space Model[J]. arXiv preprint arXiv:2407.13772, 2024.

1.论文解析

本文核心思想:VMamba四向扫描和SeNet注意力机制结合,结构参考Swin-Transformer方式。

初衷是解决基于状态空间模型(SSMs)在计算机视觉任务中的稳定性和效率问题。通过创新的调制组 Mamba 层、通道亲和调制算子和蒸馏训练目标,在多个视觉任务上取得了优异性能。

1.1 研究背景与动机

  • 状态空间模型在处理长程依赖方面有潜力,但在计算机视觉任务中,基于 SSMs 的模型面临稳定性和性能优化挑战。例如 Mamba 模型在图像分类等任务中,当参数规模扩大时不稳定,且其 VSS 块在通道维度上的参数和计算复杂度较高。
  • 卷积神经网络、视觉 Transformer 等方法在计算机视觉领域不断发展,但仍存在改进空间,促使研究人员探索新的模型结构。

1.2 相关工作

  • 卷积神经网络(ConvNets)经历了多次架构演进,ConvNeXt 等变体在性能上不断提升。
  • 视觉 Transformer(ViT)及其衍生模型对计算机视觉任务产生了重要影响,同时也引发了对注意力机制复杂度的研究,推动了高效变体的发展。
  • 状态空间模型(SSMs)逐渐成为 ViT 的替代方案,如 S4、Mamba 等模型被提出,在视觉领域也有了多种应用,但仍需解决一些固有问题。

1.3 方法

image-20241227091544964

  1. 整体架构:采用类似 Swin Transformer 的分层架构,分为四个阶段。输入图像先经过 Patch Embedding 层划分为不重叠的补丁并嵌入特征向量,然后依次通过多个调制组 Mamba 块和下采样层进行处理。
  2. 调制组 Mamba 层
    • VSSS 块:基于 Mamba 算子的令牌和通道混合器,包含 Mamba 操作和前馈网络(FFN),对输入令牌序列进行处理。
    • 分组 Mamba 操作:受分组卷积启发,将输入通道分为四组,每组应用独立的 VSSS 块,并在四个不同方向进行扫描,以更好地建模空间依赖关系,最后将结果拼接。(就是VMama,从代码上也是)
    • 通道亲和调制(CAM):为解决分组操作导致的通道间信息交换有限问题,通过平均池化计算通道统计信息,经亲和计算后重新校准分组 Mamba 算子的输出。(本质就是SENET,换了个名字)
  3. 蒸馏损失函数:针对 Mamba 训练在大模型时不稳定的问题,采用蒸馏目标与标准交叉熵目标相结合的方式。通过最小化学生模型与教师模型(RegNetY - 16G)的分类损失和蒸馏损失之和来训练模型,稳定训练过程并提升性能。

1.4 实验结果

  • 图像分类:在 ImageNet-1K 数据集上,GroupMamba-T、-S、-B 模型分别在不同参数规模下取得了较高的准确率,如 GroupMamba-T 以 2300 万参数和 4.5 GFLOPs 实现了 83.3% 的 top - 1 准确率,优于 ConvNeXt-T、Swin-T 等模型,且相比同类 SSM 模型 VMamba-T 参数减少 26%。

  • 对象检测和实例分割:在 MS - COCO 2017 数据集上,基于 Mask-RCNN 框架,GroupMamb-T 模型的 box AP 为 47.6,mask AP 为 42.9,超越了 ResNet-50、Swin-T 等模型,且参数比 VMamba-T 少 20%。

  • 语义分割:在 ADE20K 数据集上,GroupMamba-T 模型在单尺度和多尺度评估下的 mIoU 分别为 48.6 和 49.2,以 4900 万参数和 955G FLOPs 超越了 ResNet-50、Swin-T 等模型及相关 SSM 方法。

  • 消融研究:实验表明 CAM 模块和蒸馏损失函数均对模型性能有提升作用,如 GroupMamba - T 在加入 CAM 模块后准确率提升 0.3%,加入蒸馏损失函数且不增加通道数时,相比 VMamba - T 性能提升 0.8% 且参数减少 26%。

    image-20241227091923276

2.源码解析

2.1 GroupMamba.py

核心代码在models/groupmamba.py

FFN类

就是一个普通的FC-GeLu-FC模块

class FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

PVT2FFN

就是一个FC-分组卷积-GELU-FC模块

class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

GroupMambaLayer

这是论文的核心模块,实现的就是下图,类似于SE和Mamba结合。输入进来后分为三个并行分支,每个分支输入相同,分支1输入沿通道切分为四份,分别经过不同的Mamba沿着不同方向扫描处理后合并;分支2就是SE Net;分支1和分之2完成注意力加权;分支3是残差连接。

image-20241225153357873

class GroupMambaLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state=1, d_conv=3, expand=1, reduction=16):
        super().__init__()
        # 计算降维后的通道数
        num_channels_reduced = input_dim // reduction
        # 定义第一个全连接层,将输入维度映射到降维后的通道数,包含偏置
        self.fc1 = nn.Linear(input_dim, num_channels_reduced, bias=True)
        # 定义第二个全连接层,将降维后的通道数映射回输出维度,包含偏置
        self.fc2 = nn.Linear(num_channels_reduced, output_dim, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # 保存输入和输出维度
        self.input_dim = input_dim
        self.output_dim = output_dim
        # 定义层归一化,应用于输入维度
        self.norm = nn.LayerNorm(input_dim)
        
        # 定义四个 SS2D 模块,每个模块维度为输入维度的四分之一
        self.mamba_g1 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g2 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g3 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g4 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        # 定义投影层,将输入维度映射到输出维度
        self.proj = nn.Linear(input_dim, output_dim)
        # 定义一个可训练的跳跃连接缩放参数,初始值为1
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x, H, W):
        # 如果输入张量的数据类型是 float16,则转换为 float32
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        # 获取输入张量的批量大小 B、序列长度 N 和通道数 C
        B, N, C = x.shape
        # 对输入张量进行层归一化
        x = self.norm(x)

        # Channel Affinity
        # [B, N, C] -> [B, C, N]->[B,C]
        # 将输入张量 [B, N, C] 转换为 [B, C, N],
        # 然后在序列维度上取均值,得到 [B, C]
        z = x.permute(0, 2, 1).mean(dim=2)
        
        # [B, C]->[B,output_dim]
        # 通过第一个全连接层,并应用 ReLU 激活,得到 [B, output_dim]
        fc_out_1 = self.relu(self.fc1(z))
        # 通过第二个全连接层,并应用 Sigmoid 激活,得到 [B, output_dim]
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # 将输入张量重新排列为 [B, H, W, C]
        x = rearrange(x, 'b (h w) c -> b h w c', b=B, h=H, w=W, c=C)
        # 将通道维度均分为四部分,分别赋值给 x1, x2, x3, x4
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=-1)

        # 对四个不同方向应用四个 SS2D 扫描,每个扫描应用于 N/4 个通道
        x_mamba1 = self.mamba_g1(x1, CrossScan=CrossScan_1, CrossMerge=CrossMerge_1)
        x_mamba2 = self.mamba_g2(x2, CrossScan=CrossScan_2, CrossMerge=CrossMerge_2)
        x_mamba3 = self.mamba_g3(x3, CrossScan=CrossScan_3, CrossMerge=CrossMerge_3)
        x_mamba4 = self.mamba_g4(x4, CrossScan=CrossScan_4, CrossMerge=CrossMerge_4)

        # 将所有特征图在通道维度上拼接,并乘以跳跃连接缩放参数和原始输入 x
        x_mamba = torch.cat([x_mamba1, x_mamba2, x_mamba3, x_mamba4], dim=-1) * self.skip_scale * x

        # 将拼接后的特征图重新排列回 [B, N, C]
        x_mamba = rearrange(x_mamba, 'b h w c -> b (h w) c', b=B, h=H, w=W, c=C)

        # Channel Modulation
        # x_mamba:[B,N,C]  fc_out_2.unsqueeze:[B,1,C]
        # 将 x_mamba 与 fc_out_2 在通道维度上相乘,
        # fc_out_2 通过 unsqueeze 扩展维度后变为 [B, 1, C]
        x_mamba = x_mamba * fc_out_2.unsqueeze(1)
        # 对调制后的特征图进行层归一化
        x_mamba = self.norm(x_mamba)
        # 通过投影层,将特征图映射到输出维度
        x_mamba = self.proj(x_mamba)
        # 返回最终输出
        return x_mamba

上面代码中的CrossScan_1, CrossScan_2, CrossScan_3, CrossScan_4,CrossMerge_1, CrossMerge_2, CrossMerge_3, CrossMerge_4在csms6s.py中定义,见2.2

ClassBlock

类别token处理

class ClassBlock(nn.Module):
    def __init__(self, dim,  mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = GroupMambaLayer(dim, dim)
        self.mlp = FFN(dim, int(dim * mlp_ratio))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值