GroupMamba解析及测试

GroupMamba解析

Shaker A, Wasim S T, Khan S, et al. GroupMamba: Parameter-Efficient and Accurate Group Visual State Space Model[J]. arXiv preprint arXiv:2407.13772, 2024.

1.论文解析

本文核心思想:VMamba四向扫描和SeNet注意力机制结合,结构参考Swin-Transformer方式。

初衷是解决基于状态空间模型(SSMs)在计算机视觉任务中的稳定性和效率问题。通过创新的调制组 Mamba 层、通道亲和调制算子和蒸馏训练目标,在多个视觉任务上取得了优异性能。

1.1 研究背景与动机

  • 状态空间模型在处理长程依赖方面有潜力,但在计算机视觉任务中,基于 SSMs 的模型面临稳定性和性能优化挑战。例如 Mamba 模型在图像分类等任务中,当参数规模扩大时不稳定,且其 VSS 块在通道维度上的参数和计算复杂度较高。
  • 卷积神经网络、视觉 Transformer 等方法在计算机视觉领域不断发展,但仍存在改进空间,促使研究人员探索新的模型结构。

1.2 相关工作

  • 卷积神经网络(ConvNets)经历了多次架构演进,ConvNeXt 等变体在性能上不断提升。
  • 视觉 Transformer(ViT)及其衍生模型对计算机视觉任务产生了重要影响,同时也引发了对注意力机制复杂度的研究,推动了高效变体的发展。
  • 状态空间模型(SSMs)逐渐成为 ViT 的替代方案,如 S4、Mamba 等模型被提出,在视觉领域也有了多种应用,但仍需解决一些固有问题。

1.3 方法

image-20241227091544964

  1. 整体架构:采用类似 Swin Transformer 的分层架构,分为四个阶段。输入图像先经过 Patch Embedding 层划分为不重叠的补丁并嵌入特征向量,然后依次通过多个调制组 Mamba 块和下采样层进行处理。
  2. 调制组 Mamba 层
    • VSSS 块:基于 Mamba 算子的令牌和通道混合器,包含 Mamba 操作和前馈网络(FFN),对输入令牌序列进行处理。
    • 分组 Mamba 操作:受分组卷积启发,将输入通道分为四组,每组应用独立的 VSSS 块,并在四个不同方向进行扫描,以更好地建模空间依赖关系,最后将结果拼接。(就是VMama,从代码上也是)
    • 通道亲和调制(CAM):为解决分组操作导致的通道间信息交换有限问题,通过平均池化计算通道统计信息,经亲和计算后重新校准分组 Mamba 算子的输出。(本质就是SENET,换了个名字)
  3. 蒸馏损失函数:针对 Mamba 训练在大模型时不稳定的问题,采用蒸馏目标与标准交叉熵目标相结合的方式。通过最小化学生模型与教师模型(RegNetY - 16G)的分类损失和蒸馏损失之和来训练模型,稳定训练过程并提升性能。

1.4 实验结果

  • 图像分类:在 ImageNet-1K 数据集上,GroupMamba-T、-S、-B 模型分别在不同参数规模下取得了较高的准确率,如 GroupMamba-T 以 2300 万参数和 4.5 GFLOPs 实现了 83.3% 的 top - 1 准确率,优于 ConvNeXt-T、Swin-T 等模型,且相比同类 SSM 模型 VMamba-T 参数减少 26%。

  • 对象检测和实例分割:在 MS - COCO 2017 数据集上,基于 Mask-RCNN 框架,GroupMamb-T 模型的 box AP 为 47.6,mask AP 为 42.9,超越了 ResNet-50、Swin-T 等模型,且参数比 VMamba-T 少 20%。

  • 语义分割:在 ADE20K 数据集上,GroupMamba-T 模型在单尺度和多尺度评估下的 mIoU 分别为 48.6 和 49.2,以 4900 万参数和 955G FLOPs 超越了 ResNet-50、Swin-T 等模型及相关 SSM 方法。

  • 消融研究:实验表明 CAM 模块和蒸馏损失函数均对模型性能有提升作用,如 GroupMamba - T 在加入 CAM 模块后准确率提升 0.3%,加入蒸馏损失函数且不增加通道数时,相比 VMamba - T 性能提升 0.8% 且参数减少 26%。

    image-20241227091923276

2.源码解析

2.1 GroupMamba.py

核心代码在models/groupmamba.py

FFN类

就是一个普通的FC-GeLu-FC模块

class FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

PVT2FFN

就是一个FC-分组卷积-GELU-FC模块

class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

GroupMambaLayer

这是论文的核心模块,实现的就是下图,类似于SE和Mamba结合。输入进来后分为三个并行分支,每个分支输入相同,分支1输入沿通道切分为四份,分别经过不同的Mamba沿着不同方向扫描处理后合并;分支2就是SE Net;分支1和分之2完成注意力加权;分支3是残差连接。

image-20241225153357873

class GroupMambaLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state=1, d_conv=3, expand=1, reduction=16):
        super().__init__()
        # 计算降维后的通道数
        num_channels_reduced = input_dim // reduction
        # 定义第一个全连接层,将输入维度映射到降维后的通道数,包含偏置
        self.fc1 = nn.Linear(input_dim, num_channels_reduced, bias=True)
        # 定义第二个全连接层,将降维后的通道数映射回输出维度,包含偏置
        self.fc2 = nn.Linear(num_channels_reduced, output_dim, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # 保存输入和输出维度
        self.input_dim = input_dim
        self.output_dim = output_dim
        # 定义层归一化,应用于输入维度
        self.norm = nn.LayerNorm(input_dim)
        
        # 定义四个 SS2D 模块,每个模块维度为输入维度的四分之一
        self.mamba_g1 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g2 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g3 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        self.mamba_g4 = SS2D(
            d_model=input_dim // 4,
            d_state=d_state,
            ssm_ratio=expand,
            d_conv=d_conv
        )
        # 定义投影层,将输入维度映射到输出维度
        self.proj = nn.Linear(input_dim, output_dim)
        # 定义一个可训练的跳跃连接缩放参数,初始值为1
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x, H, W):
        # 如果输入张量的数据类型是 float16,则转换为 float32
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        # 获取输入张量的批量大小 B、序列长度 N 和通道数 C
        B, N, C = x.shape
        # 对输入张量进行层归一化
        x = self.norm(x)

        # Channel Affinity
        # [B, N, C] -> [B, C, N]->[B,C]
        # 将输入张量 [B, N, C] 转换为 [B, C, N],
        # 然后在序列维度上取均值,得到 [B, C]
        z = x.permute(0, 2, 1).mean(dim=2)
        
        # [B, C]->[B,output_dim]
        # 通过第一个全连接层,并应用 ReLU 激活,得到 [B, output_dim]
        fc_out_1 = self.relu(self.fc1(z))
        # 通过第二个全连接层,并应用 Sigmoid 激活,得到 [B, output_dim]
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # 将输入张量重新排列为 [B, H, W, C]
        x = rearrange(x, 'b (h w) c -> b h w c', b=B, h=H, w=W, c=C)
        # 将通道维度均分为四部分,分别赋值给 x1, x2, x3, x4
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=-1)

        # 对四个不同方向应用四个 SS2D 扫描,每个扫描应用于 N/4 个通道
        x_mamba1 = self.mamba_g1(x1, CrossScan=CrossScan_1, CrossMerge=CrossMerge_1)
        x_mamba2 = self.mamba_g2(x2, CrossScan=CrossScan_2, CrossMerge=CrossMerge_2)
        x_mamba3 = self.mamba_g3(x3, CrossScan=CrossScan_3, CrossMerge=CrossMerge_3)
        x_mamba4 = self.mamba_g4(x4, CrossScan=CrossScan_4, CrossMerge=CrossMerge_4)

        # 将所有特征图在通道维度上拼接,并乘以跳跃连接缩放参数和原始输入 x
        x_mamba = torch.cat([x_mamba1, x_mamba2, x_mamba3, x_mamba4], dim=-1) * self.skip_scale * x

        # 将拼接后的特征图重新排列回 [B, N, C]
        x_mamba = rearrange(x_mamba, 'b h w c -> b (h w) c', b=B, h=H, w=W, c=C)

        # Channel Modulation
        # x_mamba:[B,N,C]  fc_out_2.unsqueeze:[B,1,C]
        # 将 x_mamba 与 fc_out_2 在通道维度上相乘,
        # fc_out_2 通过 unsqueeze 扩展维度后变为 [B, 1, C]
        x_mamba = x_mamba * fc_out_2.unsqueeze(1)
        # 对调制后的特征图进行层归一化
        x_mamba = self.norm(x_mamba)
        # 通过投影层,将特征图映射到输出维度
        x_mamba = self.proj(x_mamba)
        # 返回最终输出
        return x_mamba

上面代码中的CrossScan_1, CrossScan_2, CrossScan_3, CrossScan_4,CrossMerge_1, CrossMerge_2, CrossMerge_3, CrossMerge_4在csms6s.py中定义,见2.2

ClassBlock

类别token处理

class ClassBlock(nn.Module):
    def __init__(self, dim,  mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = GroupMambaLayer(dim, dim)
        self.mlp = FFN(dim, int(dim * mlp_ratio))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        # X:[B, N, C] x[:, :1]:提取出每个样本的第一个 Token
        # cls_embed:[B,1,C] 
        # 获取类别嵌入,提取输入张量的第一个 token
        cls_embed = x[:, :1]
        # 通过注意力层处理分类嵌入,并与原始分类嵌入相加
        cls_embed = cls_embed + self.norm1(self.attn(x[:, :1], H, W))
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed), H, W)
        return torch.cat([cls_embed, x[:, 1:]], dim=1)

cls_embed 的作用:在许多基于 Transformer 的模型(例如 BERT、ViT 等)中,第一个 Token 通常被设计为分类 Token(Classification Token),简称 cls。这个 cls Token 的输出会被用作整个序列的表示,用于下游的分类任务。通过 cls_embed = x[:, :1],我们提取出这个分类 Token,以便后续进行处理(如通过注意力层和前馈网络进行特征增强)。

ClassBlockforward 方法中,cls_embed 被进一步处理,通过归一化层 self.norm1 和注意力层 self.attncls_embed 进行处理,并将处理后的结果与原始的 cls_embed 相加,实现特征的增强和融合。最终,将处理后的 cls_embed 与输入张量中除第一个 Token 以外的其他 Token 拼接在一起,形成最终的输出。

Block_mamba
常规Mamba块

class Block_mamba(nn.Module):
    def __init__(self, 
        dim, 
        mlp_ratio,
        drop_path=0., 
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm2 = norm_layer(dim)

        self.attn = GroupMambaLayer(dim, dim)
        self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        # 将输入张量 x 通过注意力层处理,并应用 DropPath 后与原始输入相加
        x = x + self.drop_path(self.attn(x, H, W))
        # 对加和后的结果进行归一化,然后通过前馈网络处理,并应用 DropPath 后与原始输入相加
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x

DownSamples

使用卷积完成特征图尺寸减半,并调整通道顺序,使其符合Mamba [B,N,C],通道在最后的输入要求。

class DownSamples(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # 通过卷积层进行下采样,输出特征图大小减半
        x = self.proj(x)
        # 获取下采样后特征图的高度和宽度
        _, _, H, W = x.shape
        # 将特征图展平,从第三维开始展平,并转置,使得形状为 [B, H*W, C]
        x = x.flatten(2).transpose(1, 2)
        # 对展平后的特征进行层归一化
        x = self.norm(x)
        # 返回归一化后的特征以及高度和宽度
        return x, H, W

Stem

完成初步特征提取,对输入进行4倍下采样,并调整shape为[B,N,C],符合Mamba输入

class Stem(nn.Module):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            # 第一个卷积层,使用 7x7 卷积核,实现2倍下采样
            nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            # 第二个卷积层,使用 3x3 卷积核,不改变尺寸
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            # 第三个卷积层,使用 3x3 卷积核,不改变尺寸
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        # 2倍下采样
        self.proj = nn.Conv2d(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # 通过卷积层序列进行特征提取和初步处理,尺寸减半
        x = self.conv(x)
        # 通过卷积层进行下采样,输出特征图大小减半,输出通道数转换为 out_channels
        x = self.proj(x)
        _, _, H, W = x.shape
        # 将特征图展平,从第三维开始展平,并转置,使得形状为 [B, H*W, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

GroupMamba

这个是核心,组合了前面所有

class GroupMamba(nn.Module):
    def __init__(self, 
        in_chans=3, 
        num_classes=1000, 
        stem_hidden_dim = 32,
        embed_dims=[64, 128, 348, 448],
        mlp_ratios=[8, 8, 4, 4], 
        drop_path_rate=0., 
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3], # 每个阶段的深度(Block 的数量)
        num_stages=4, # 阶段数量
        distillation=True,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages
        # 生成 DropPath 的概率列表,线性递增
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0 # 初始化当前深度索引
        for i in range(num_stages): # 遍历每个阶段
            # 如果是第一个阶段,使用 Stem 模块进行特征提取
            if i == 0: # [3,224,224] -> [64*64,64]
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else: # 否则,使用 DownSamples 模块进行下采样
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])
            # 为当前阶段创建 Block_mamba 模块列表
            block = nn.ModuleList([Block_mamba(
                    dim = embed_dims[i],
                    mlp_ratio = mlp_ratios[i],
                    drop_path=dpr[cur + j],
                    norm_layer=norm_layer)
                for j in range(depths[i])])
            # 为当前阶段创建归一化层
            norm = norm_layer(embed_dims[i])
            # 更新当前深度索引
            cur += depths[i]
            # 动态添加 patch_embed 模块到类中
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            # 动态添加 block 模块到类中
            setattr(self, f"block{i + 1}", block)
            # 动态添加归一化层到类中
            setattr(self, f"norm{i + 1}", norm)
        # 定义后处理层列表,这里仅包含 'ca' 一个后处理层
        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            # 创建后处理网络模块列表,使用 ClassBlock
            ClassBlock(
                dim = embed_dims[-1], 
                mlp_ratio = mlp_ratios[-1],
                norm_layer=norm_layer)
            for _ in range(len(post_layers))
        ])

        # classification head
        # 定义分类头,如果类别数量大于0,则使用全连接层,否则使用恒等映射
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        # distillation head
        # 保存是否使用蒸馏
        self.dist = distillation
        if self.dist:  # 如果使用蒸馏,定义蒸馏头cc
            self.dist_head = nn.Linear(
                embed_dims[-1], num_classes) if num_classes > 0 \
                else nn.Identity()
        # 如果使用蒸馏,定义蒸馏头
        self.apply(self._init_weights)

    # 定义权重初始化的方法
    def _init_weights(self, m):
        # 如果模块是全连接层
        if isinstance(m, nn.Linear):
            # 使用截断正态分布初始化权重,标准差为 0.02
            trunc_normal_(m.weight, std=.02)
            # 如果全连接层有偏置,则将偏置初始化为 0
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        # 如果模块是层归一化
        elif isinstance(m, nn.LayerNorm):
            # 将偏置初始化为 0
            nn.init.constant_(m.bias, 0)
            # 将权重初始化为 1.0
            nn.init.constant_(m.weight, 1.0)
        # 如果模块是二维卷积层
        elif isinstance(m, nn.Conv2d):
            # 计算 fan_out,等于卷积核尺寸乘以输出通道数
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # 除以组数,得到平均 fan_out
            fan_out //= m.groups
            # 使用正态分布初始化权重,标准差为 sqrt(2.0 / fan_out)
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            # 如果卷积层有偏置,则将偏置初始化为 0
            if m.bias is not None:
                m.bias.data.zero_()

    # 定义分类前向传播方法
    def forward_cls(self, x, H, W):
        # 计算类别 token 的嵌入,通过对序列维度取均值
        cls_tokens = x.mean(dim=1, keepdim=True)
        # 将类别嵌入与原始输入拼接在一起
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            # 遍历后处理网络中的每个模块,进行处理
            x = block(x, H, W)
        return x

    # 定义特征提取的前向传播方法
    def forward_features(self, x):
        B = x.shape[0] # 获取批量大小
        # 遍历每个阶段,进行特征提取和处理
        for i in range(self.num_stages):
            # 获取当前阶段的 patch_embed 模块
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            # 获取当前阶段的 block 模块列表
            block = getattr(self, f"block{i + 1}")
            # 通过 patch_embed 模块进行嵌入和下采样
            x, H, W = patch_embed(x)
            # 遍历当前阶段的每个 Block_mamba 模块,进行处理
            for blk in block:
                x = blk(x, H, W)
            # 如果不是最后一个阶段,进行归一化和维度变换
            if i != self.num_stages - 1:
                # 获取当前阶段的归一化层
                norm = getattr(self, f"norm{i + 1}")
                # 对特征进行归一化
                x = norm(x)
                # 重塑特征图形状,并进行转置,使其适应后续的卷积层
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # 通过分类前向传播方法处理特征,并提取第一个 token
        # cls_tokens :[B, 1, C] 与原始特征x在dim=1拼接,得到[B, N+1, C]
        # self.forward_cls : [B, N+1, C] 
        # [:, 0]:提取第一个 token 的特征向量,形状为 [B, C]
        x = self.forward_cls(x, 1, 1)[:, 0]
        # 获取最后一个阶段的归一化层
        norm = getattr(self, f"norm{self.num_stages}")
        # 对特征进行归一化
        x = norm(x)
        # 返回归一化后的特征
        return x

    # 定义整体前向传播方法
    def forward(self, x):
        # 提取特征
        x = self.forward_features(x)
        if self.dist:# 如果使用蒸馏
            # 通过分类头和蒸馏头获取输出
            cls_out = self.head(x), self.dist_head(x)
            # 如果不是训练模式,合并两个输出
            if not self.training:
                cls_out = (cls_out[0] + cls_out[1]) / 2
        else: # 如果不使用蒸馏,只通过分类头获取输出
            cls_out = self.head(x)

        return cls_out

DWConv

深度可分离卷积层

class DWConv(nn.Module):
   
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        # 获取输入张量的批量大小 B、序列长度 N 和通道数 C
        B, N, C = x.shape
        # 转置张量,将通道维度移到第二维,并重塑为 [B, C, H, W]
        x = x.transpose(1, 2).view(B, C, H, W)
        # 通过深度可分离卷积层进行空间特征提取
        x = self.dwconv(x)
        # 将卷积后的张量展平,从第三维开始展平,并转置回 [B, N, C] 的形状
        x = x.flatten(2).transpose(1, 2)
        return x

后面是三个不同体量的模型版本

@register_model
def groupmamba_tiny(pretrained=False, **kwargs):
    model = GroupMamba(
        stem_hidden_dim = 32,
        embed_dims = [64, 128, 348, 448], 
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 4, 9, 3],
        **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def groupmamba_small(pretrained=False, **kwargs):
    model = GroupMamba(
        stem_hidden_dim = 64,
        embed_dims = [64, 128, 348, 512], 
        mlp_ratios = [8, 8, 4, 4], 
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 4, 16, 3],
        **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def groupmamba_base(pretrained=False, **kwargs):
    model = GroupMamba(
        stem_hidden_dim = 64,
        embed_dims = [96, 192, 424, 512],
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 6, 21, 3],
        **kwargs)
    model.default_cfg = _cfg()
    return model

2.2 csms6s.py

定义了VMamba: Visual State Space Model

CrossScan

实现很巧妙,通过翻转实现不同方向的扫描

class CrossScan(torch.autograd.Function):
    # 自定义的自动微分函数,继承自 torch.autograd.Function,
    # 用于实现特定的前向和反向传播操作
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        # 将输入张量的形状信息保存到上下文 ctx 中,以便在反向传播时使用
        ctx.shape = (B, C, H, W)
        # 创建一个新的空张量 xs,形状为 [B, 4, C, H * W]
        xs = x.new_empty((B, 4, C, H * W))
        # 将输入张量 x 在高度和宽度维度上展平,得到形状为 [B, C, H * W]
        xs[:, 0] = x.flatten(2, 3)
        # 将输入张量 x 在高度和宽度维度上进行转置,然后展平,得到形状为 [B, C, H * W]
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        # 对前两个切片进行翻转,得到形状为 [B, 2, C, H * W]
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """
                反向传播方法,用于计算输入张量 x 的梯度。

                参数:
                    ctx: 上下文对象,包含前向传播中保存的信息。
                    ys: 上游的梯度,形状为 [B, 4, C, H * W]

                返回:
                    y: 输入张量 x 的梯度,形状为 [B, C, H, W]
        """
        # 从上下文中恢复输入张量 x 的形状信息
        B, C, H, W = ctx.shape
        # 计算 H * W 的值
        L = H * W
        # 将上游梯度 ys 切分为前两个和后两个部分,并进行翻转和求和
        # ys[:, 0:2] 的形状为 [B, 2, C, H * W]
        # ys[:, 2:4] 先沿最后一个维度翻转,然后与 ys[:, 0:2] 相加
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        
        # 将处理后的梯度 ys 切分为两部分
        # ys[:, 0] 的形状为 [B, C, H * W]
        # ys[:, 1] 先重塑为 [B, C, W, H],然后转置为 [B, C, H, W],最后展平为 [B, C, H * W]
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        # 将梯度 y 从 [B, C, H * W] 重塑为 [B, C, H, W]
        return y.view(B, -1, H, W)

核心就是:

        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)

举例说明:

假设输入张量 x 如下(简化为 1 个样本,1 个通道,2x2 图像)

x = [
  [
    [[a, b],
     [c, d]]
  ]
]  # Shape: [1, 1, 2, 2]

经过前向传播后的 xs 各个切片如下:

原始展平特征 (xs[:, 0]):

[
  [
    [a, b, c, d]
  ]
]  # Shape: [1, 1, 4]

转置后展平特征 (xs[:, 1]):

[
  [
    [a, c, b, d]
  ]
]  # Shape: [1, 1, 4]

原始展平特征的翻转 (xs[:, 2]):

[
  [
    [d, c, b, a]
  ]
]  # Shape: [1, 1, 4]

转置后展平特征的翻转 (xs[:, 3]):

[
  [
    [d, b, c, a]
  ]
]  # Shape: [1, 1, 4]

即实现的就是如下图的四种不同方向的扫描:

image-20241226144434500

CrossMerge

用于特征融合

前向传播:

  • 将输入张量 ys进行四种不同的特征变换:
    • 原始展平: 直接展平高度和宽度维度。
    • 转置后展平: 先转置高度和宽度,再展平。
    • 翻转原始展平: 对原始展平结果进行翻转。
    • 翻转转置展平: 对转置后展平结果进行翻转。
    • 通过相加操作融合这些变换后的特征,得到最终的输出 y

反向传播:

  • 将上游梯度 x 进行逆向操作,恢复输入张量 ys 的梯度,确保梯度的正确传播。
class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        # 获取输入张量 ys 的形状,分别为批量大小 B、组数 K、维度 D、高度 H 和宽度 W
        B, K, D, H, W = ys.shape
        # 将高度和宽度信息保存到上下文 ctx 中,以便在反向传播时使用
        ctx.shape = (H, W)
        
        # 重塑 ys 张量的形状,将高度和宽度展平为一个维度,得到形状 [B, K, D, H * W]
        ys = ys.view(B, K, D, -1)
        
        # 对 ys 张量进行切片和翻转操作
        # ys[:, 0:2] 选择前两组,形状为 [B, 2, D, H * W]
        # ys[:, 2:4] 选择后两组,形状为 [B, 2, D, H * W]
        # 对后两组进行翻转操作,沿最后一个维度(H * W)翻转
        # 将翻转后的后两组与前两组相加,得到新的 ys,形状仍为 [B, 2, D, H * W]
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        
        # 进一步处理 ys 张量
        # ys[:, 0] 形状为 [B, D, H * W]
        # ys[:, 1] 形状为 [B, D, H * W]
        # 将 ys[:, 1] 重塑为 [B, D, W, H],然后转置为 [B, D, H, W]
        # 使用 contiguous() 保证内存连续性,再展平为 [B, D, H * W]
        # 将 ys[:, 0] 和处理后的 ys[:, 1] 相加,得到 y,形状为 [B, D, H * W]
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        
        # 返回处理后的张量 y,形状为 [B, D, H * W]
        return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """
        反向传播方法,用于计算输入张量 ys 的梯度。
        
        参数:
            ctx: 上下文对象,包含前向传播中保存的信息。
            x: 上游的梯度,形状为 [B, D, H * W]
        
        返回:
            xs: 输入张量 ys 的梯度,形状为 [B, 4, D, H, W]
        """
        # 从上下文中恢复高度和宽度信息
        H, W = ctx.shape
        # 获取输入张量 x 的形状,批量大小 B,维度 D,展平长度 L = H * W
        B, C, L = x.shape
        
        # 创建一个新的空张量 xs,形状为 [B, 4, C, L]
        xs = x.new_empty((B, 4, C, L))
        
        # 将输入梯度 x 赋值给 xs 的第一个切片
        xs[:, 0] = x  # 对应前向传播中的 ys[:, 0]
        
        # 将输入梯度 x 重塑为 [B, C, H, W],然后转置为 [B, C, W, H],再展平为 [B, C, L]
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)  # 对应前向传播中的 ys[:, 1]
        
        # 对前两个切片进行翻转操作,得到 [B, 2, C, L]
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])  # 对应前向传播中的 ys[:, 2:4]
        
        # 将 xs 重塑回 [B, 4, C, H, W]
        xs = xs.view(B, 4, C, H, W)
        
        # 返回输入梯度 ys,形状为 [B, 4, C, H, W]
        return xs

举例说明,假设输入张量 ys 的形状为 [1, 4, 2, 2, 2],即 B=1, K=4, D=2, H=2, W=2

ys = torch.tensor([[
    [
        [[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]
    ],
    [
        [[9, 10],
         [11, 12]],
        [[13, 14],
         [15, 16]]
    ],
    [
        [[17, 18],
         [19, 20]],
        [[21, 22],
         [23, 24]]
    ],
    [
        [[25, 26],
         [27, 28]],
        [[29, 30],
         [31, 32]]
    ]
]], dtype=torch.float32)  # Shape: [1, 4, 2, 2, 2]

切片和翻转操作:

  • 选择前两组

    ys[:, 0:2] = [
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8]
        ],
        [
            [9, 10, 11, 12],
            [13, 14, 15, 16]
        ]
    ]  # Shape: [1, 2, 2, 4]
    
  • 选择后两组并翻转

flipped = torch.flip(ys[:, 2:4], dims=[-1]) = [
    [
        [20, 19, 18, 17],
        [24, 23, 22, 21]
    ],
    [
        [28, 27, 26, 25],
        [32, 31, 30, 29]
    ]
]  # Shape: [1, 2, 2, 4]

  • 相加操作:
ys = ys[:, 0:2] + flipped = [
    [
        [1+20, 2+19, 3+18, 4+17],
        [5+24, 6+23, 7+22, 8+21]
    ],
    [
        [9+28, 10+27, 11+26, 12+25],
        [13+32, 14+31, 15+30, 16+29]
    ]
] = [
    [
        [21, 21, 21, 21],
        [29, 29, 29, 29]
    ],
    [
        [37, 37, 37, 37],
        [45, 45, 45, 45]
    ]
]  # Shape: [1, 2, 2, 4]

进一步处理:

  • 处理第一组: ys[:, 0] = [21, 21, 21, 21]

  • 处理第二组:

    ys[:, 1].view(1, -1, 2, 2) = [
        [
            [21, 21],
            [21, 21]
        ],
        [
            [37, 37],
            [37, 37]
        ],
        [
            [29, 29],
            [29, 29]
        ],
        [
            [45, 45],
            [45, 45]
        ]
    ]  # Shape: [1, 2, 2, 2]
    
  • 转置: 将高度和宽度维度交换,得到 [B, D, H, W] 形状的张量:

ys[:, 1].view(1, -1, 2, 2).transpose(2, 3) = [
    [
        [21, 21],
        [21, 21]
    ],
    [
        [37, 37],
        [37, 37]
    ],
    [
        [29, 29],
        [29, 29]
    ],
    [
        [45, 45],
        [45, 45]
    ]
]  # Shape: [1, 2, 2, 2]
  • 展平: 展平为 [B, D, H * W],即 [1, 2, 4]
y = [21+21, 21+21, 21+21, 21+21] + [37+37, 37+37, 37+37, 37+37]
    = [42, 42, 42, 42] + [74, 74, 74, 74]
    = [116, 116, 116, 116]
# Shape: [1, 2, 4]

最终前向输出:

y = torch.tensor([[
    [116, 116, 116, 116],
    [148, 148, 148, 148]
]], dtype=torch.float32)  # Shape: [1, 2, 4]

CrossScan_1和CrossMerge_1

这两个是一组,可以看成是CrossScan的第一分支,包含的就是原始数据,没有进行转置和翻转

class CrossScan_1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # 获取输入张量 x 的形状,分别为批量大小 B、通道数 C、高度 H 和宽度 W
        B, C, H, W = x.shape
        # 将输入张量的形状信息保存到上下文 ctx 中,以便在反向传播时使用
        ctx.shape = (B, C, H, W)

        # 创建一个新的空张量 xs,形状为 [B, 1, C, H * W]
        xs = x.new_empty((B, 1, C, H * W))

        # 将输入张量 x 在高度和宽度维度上展平,赋值给 xs 的第一个切片
        xs[:, 0] = x.flatten(2, 3)

        # 返回处理后的张量 xs,形状为 [B, 1, C, H * W]
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """
        反向传播方法,用于计算输入张量 x 的梯度。

        参数:
            ctx: 上下文对象,包含前向传播中保存的信息。
            ys: 上游的梯度,形状为 [B, 1, C, H * W]

        返回:
            y: 输入张量 x 的梯度,形状为 [B, C, H, W]
        """
        # 从上下文中恢复输入张量 x 的形状信息
        B, C, H, W = ctx.shape
        # 计算展平后的长度 L = H * W
        L = H * W

        # 从上游梯度 ys 中提取第一个切片,形状为 [B, C, H * W]
        y = ys[:, 0]

        # 将梯度 y 从 [B, C, H * W] 重塑为 [B, C, H, W]
        return y.view(B, -1, H, W)

class CrossMerge_1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        y = ys[:, 0]
        return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 1, C, L))
        xs[:, 0] = x
        xs = xs.view(B, 1, C, H, W)
        return xs

CrossScan_2和CrossMerge_2

交换宽高后展平

class CrossScan_2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 1, C, H * W))
        xs[:, 0] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        y = ys[:, 0].view(B, -1, H, W).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)

CrossScan_3

展平后翻转

class CrossScan_3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 1, C, H * W))
        xs[:, 0] = torch.flip(x.flatten(2, 3), dims=[-1])
        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        y = ys[:, 0].flip(dims=[-1]).view(B, 1, -1, L)
        return y.view(B, -1, H, W)

CrossScan_4

转置展平后翻转

class CrossScan_4(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 1, C, H * W))
        xs[:, 0] = torch.flip(x.transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1])
        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        y = ys[:, 0].view(B, -1, H, W).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L).flip(dims=[-1]).view(B, 1, -1, L)
        return y.view(B, -1, H, W)

3.测试

原先的VMamba环境和Vim环境均报错,这里直接新建环境,省得麻烦,按照Groupmamba github作者的步骤:

conda create -n groupmamba python=3.10.13
conda activate groupmamba

安装Pytorch-GPU,实测pytorch-cuda=12.1会报错,所以安装11的

conda install cudatoolkit==11.8 -c nvidia

pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

cd到源码包中安装其他依赖

cd groupmamba
pip install -r requirements.txt

cd到源码包中安装Mamba,貌似Vmamba和原始Mamba都不行,原先VMamba中提供的selective_scan是0.0.1版本,工作正常,现在这个版本是0.0.2。Groupmamba github给的方式是直接在源码目录中执行如下命令安装,可以安装正常,如果报错可以参考这个解决:https://github.com/MzeroMiko/VMamba/issues/216)

cd kernels/selective_scan && pip install .

在groupmamba下写测试代码:

   model = groupmamba_base()
    # print(model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    output = model(dummy_input)
    print(output[0].shape) # 分类头输出
    print(output[1].shape) # 蒸馏头输出

输出:

torch.Size([1, 1000])
torch.Size([1, 1000])
内容概要:本文详细探讨了基于樽海鞘算法(SSA)优化的极限学习机(ELM)在回归预测任务中的应用,并与传统的BP神经网络、广义回归神经网络(GRNN)以及未优化的ELM进行了性能对比。首先介绍了ELM的基本原理,即通过随机生成输入层与隐藏层之间的连接权重及阈值,仅需计算输出权重即可快速完成训练。接着阐述了SSA的工作机制,利用樽海鞘群体觅食行为优化ELM的输入权重和隐藏层阈值,从而提高模型性能。随后分别给出了BP、GRNN、ELM和SSA-ELM的具体实现代码,并通过波士顿房价数据集和其他工业数据集验证了各模型的表现。结果显示,SSA-ELM在预测精度方面显著优于其他三种方法,尽管其训练时间较长,但在实际应用中仍具有明显优势。 适合人群:对机器学习尤其是回归预测感兴趣的科研人员和技术开发者,特别是那些希望深入了解ELM及其优化方法的人。 使用场景及目标:适用于需要高效、高精度回归预测的应用场景,如金融建模、工业数据分析等。主要目标是提供一种更为有效的回归预测解决方案,尤其是在处理大规模数据集时能够保持较高的预测精度。 其他说明:文中提供了详细的代码示例和性能对比图表,帮助读者更好地理解和复现实验结果。同时提醒使用者注意SSA参数的选择对模型性能的影响,建议进行参数敏感性分析以获得最佳效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值