GroupMamba解析
Shaker A, Wasim S T, Khan S, et al. GroupMamba: Parameter-Efficient and Accurate Group Visual State Space Model[J]. arXiv preprint arXiv:2407.13772, 2024.
1.论文解析
本文核心思想:VMamba四向扫描和SeNet注意力机制结合,结构参考Swin-Transformer方式。
初衷是解决基于状态空间模型(SSMs)在计算机视觉任务中的稳定性和效率问题。通过创新的调制组 Mamba 层、通道亲和调制算子和蒸馏训练目标,在多个视觉任务上取得了优异性能。
1.1 研究背景与动机
- 状态空间模型在处理长程依赖方面有潜力,但在计算机视觉任务中,基于 SSMs 的模型面临稳定性和性能优化挑战。例如 Mamba 模型在图像分类等任务中,当参数规模扩大时不稳定,且其 VSS 块在通道维度上的参数和计算复杂度较高。
- 卷积神经网络、视觉 Transformer 等方法在计算机视觉领域不断发展,但仍存在改进空间,促使研究人员探索新的模型结构。
1.2 相关工作
- 卷积神经网络(ConvNets)经历了多次架构演进,ConvNeXt 等变体在性能上不断提升。
- 视觉 Transformer(ViT)及其衍生模型对计算机视觉任务产生了重要影响,同时也引发了对注意力机制复杂度的研究,推动了高效变体的发展。
- 状态空间模型(SSMs)逐渐成为 ViT 的替代方案,如 S4、Mamba 等模型被提出,在视觉领域也有了多种应用,但仍需解决一些固有问题。
1.3 方法
- 整体架构:采用类似 Swin Transformer 的分层架构,分为四个阶段。输入图像先经过 Patch Embedding 层划分为不重叠的补丁并嵌入特征向量,然后依次通过多个调制组 Mamba 块和下采样层进行处理。
- 调制组 Mamba 层
- VSSS 块:基于 Mamba 算子的令牌和通道混合器,包含 Mamba 操作和前馈网络(FFN),对输入令牌序列进行处理。
- 分组 Mamba 操作:受分组卷积启发,将输入通道分为四组,每组应用独立的 VSSS 块,并在四个不同方向进行扫描,以更好地建模空间依赖关系,最后将结果拼接。(就是VMama,从代码上也是)
- 通道亲和调制(CAM):为解决分组操作导致的通道间信息交换有限问题,通过平均池化计算通道统计信息,经亲和计算后重新校准分组 Mamba 算子的输出。(本质就是SENET,换了个名字)
- 蒸馏损失函数:针对 Mamba 训练在大模型时不稳定的问题,采用蒸馏目标与标准交叉熵目标相结合的方式。通过最小化学生模型与教师模型(RegNetY - 16G)的分类损失和蒸馏损失之和来训练模型,稳定训练过程并提升性能。
1.4 实验结果
-
图像分类:在 ImageNet-1K 数据集上,GroupMamba-T、-S、-B 模型分别在不同参数规模下取得了较高的准确率,如 GroupMamba-T 以 2300 万参数和 4.5 GFLOPs 实现了 83.3% 的 top - 1 准确率,优于 ConvNeXt-T、Swin-T 等模型,且相比同类 SSM 模型 VMamba-T 参数减少 26%。
-
对象检测和实例分割:在 MS - COCO 2017 数据集上,基于 Mask-RCNN 框架,GroupMamb-T 模型的 box AP 为 47.6,mask AP 为 42.9,超越了 ResNet-50、Swin-T 等模型,且参数比 VMamba-T 少 20%。
-
语义分割:在 ADE20K 数据集上,GroupMamba-T 模型在单尺度和多尺度评估下的 mIoU 分别为 48.6 和 49.2,以 4900 万参数和 955G FLOPs 超越了 ResNet-50、Swin-T 等模型及相关 SSM 方法。
-
消融研究:实验表明 CAM 模块和蒸馏损失函数均对模型性能有提升作用,如 GroupMamba - T 在加入 CAM 模块后准确率提升 0.3%,加入蒸馏损失函数且不增加通道数时,相比 VMamba - T 性能提升 0.8% 且参数减少 26%。
2.源码解析
2.1 GroupMamba.py
核心代码在models/groupmamba.py
FFN类
就是一个普通的FC-GeLu-FC模块
class FFN(nn.Module):
def __init__(self, in_features, hidden_features):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, in_features)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
PVT2FFN
就是一个FC-分组卷积-GELU-FC模块
class PVT2FFN(nn.Module):
def __init__(self, in_features, hidden_features):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, in_features)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.fc2(x)
return x
GroupMambaLayer
这是论文的核心模块,实现的就是下图,类似于SE和Mamba结合。输入进来后分为三个并行分支,每个分支输入相同,分支1输入沿通道切分为四份,分别经过不同的Mamba沿着不同方向扫描处理后合并;分支2就是SE Net;分支1和分之2完成注意力加权;分支3是残差连接。
class GroupMambaLayer(nn.Module):
def __init__(self, input_dim, output_dim, d_state=1, d_conv=3, expand=1, reduction=16):
super().__init__()
# 计算降维后的通道数
num_channels_reduced = input_dim // reduction
# 定义第一个全连接层,将输入维度映射到降维后的通道数,包含偏置
self.fc1 = nn.Linear(input_dim, num_channels_reduced, bias=True)
# 定义第二个全连接层,将降维后的通道数映射回输出维度,包含偏置
self.fc2 = nn.Linear(num_channels_reduced, output_dim, bias=True)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
# 保存输入和输出维度
self.input_dim = input_dim
self.output_dim = output_dim
# 定义层归一化,应用于输入维度
self.norm = nn.LayerNorm(input_dim)
# 定义四个 SS2D 模块,每个模块维度为输入维度的四分之一
self.mamba_g1 = SS2D(
d_model=input_dim // 4,
d_state=d_state,
ssm_ratio=expand,
d_conv=d_conv
)
self.mamba_g2 = SS2D(
d_model=input_dim // 4,
d_state=d_state,
ssm_ratio=expand,
d_conv=d_conv
)
self.mamba_g3 = SS2D(
d_model=input_dim // 4,
d_state=d_state,
ssm_ratio=expand,
d_conv=d_conv
)
self.mamba_g4 = SS2D(
d_model=input_dim // 4,
d_state=d_state,
ssm_ratio=expand,
d_conv=d_conv
)
# 定义投影层,将输入维度映射到输出维度
self.proj = nn.Linear(input_dim, output_dim)
# 定义一个可训练的跳跃连接缩放参数,初始值为1
self.skip_scale = nn.Parameter(torch.ones(1))
def forward(self, x, H, W):
# 如果输入张量的数据类型是 float16,则转换为 float32
if x.dtype == torch.float16:
x = x.type(torch.float32)
# 获取输入张量的批量大小 B、序列长度 N 和通道数 C
B, N, C = x.shape
# 对输入张量进行层归一化
x = self.norm(x)
# Channel Affinity
# [B, N, C] -> [B, C, N]->[B,C]
# 将输入张量 [B, N, C] 转换为 [B, C, N],
# 然后在序列维度上取均值,得到 [B, C]
z = x.permute(0, 2, 1).mean(dim=2)
# [B, C]->[B,output_dim]
# 通过第一个全连接层,并应用 ReLU 激活,得到 [B, output_dim]
fc_out_1 = self.relu(self.fc1(z))
# 通过第二个全连接层,并应用 Sigmoid 激活,得到 [B, output_dim]
fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
# 将输入张量重新排列为 [B, H, W, C]
x = rearrange(x, 'b (h w) c -> b h w c', b=B, h=H, w=W, c=C)
# 将通道维度均分为四部分,分别赋值给 x1, x2, x3, x4
x1, x2, x3, x4 = torch.chunk(x, 4, dim=-1)
# 对四个不同方向应用四个 SS2D 扫描,每个扫描应用于 N/4 个通道
x_mamba1 = self.mamba_g1(x1, CrossScan=CrossScan_1, CrossMerge=CrossMerge_1)
x_mamba2 = self.mamba_g2(x2, CrossScan=CrossScan_2, CrossMerge=CrossMerge_2)
x_mamba3 = self.mamba_g3(x3, CrossScan=CrossScan_3, CrossMerge=CrossMerge_3)
x_mamba4 = self.mamba_g4(x4, CrossScan=CrossScan_4, CrossMerge=CrossMerge_4)
# 将所有特征图在通道维度上拼接,并乘以跳跃连接缩放参数和原始输入 x
x_mamba = torch.cat([x_mamba1, x_mamba2, x_mamba3, x_mamba4], dim=-1) * self.skip_scale * x
# 将拼接后的特征图重新排列回 [B, N, C]
x_mamba = rearrange(x_mamba, 'b h w c -> b (h w) c', b=B, h=H, w=W, c=C)
# Channel Modulation
# x_mamba:[B,N,C] fc_out_2.unsqueeze:[B,1,C]
# 将 x_mamba 与 fc_out_2 在通道维度上相乘,
# fc_out_2 通过 unsqueeze 扩展维度后变为 [B, 1, C]
x_mamba = x_mamba * fc_out_2.unsqueeze(1)
# 对调制后的特征图进行层归一化
x_mamba = self.norm(x_mamba)
# 通过投影层,将特征图映射到输出维度
x_mamba = self.proj(x_mamba)
# 返回最终输出
return x_mamba
上面代码中的CrossScan_1, CrossScan_2, CrossScan_3, CrossScan_4,CrossMerge_1, CrossMerge_2, CrossMerge_3, CrossMerge_4在csms6s.py中定义,见2.2
ClassBlock
类别token处理
class ClassBlock(nn.Module):
def __init__(self, dim, mlp_ratio, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = GroupMambaLayer(dim, dim)
self.mlp = FFN(dim, int(dim * mlp_ratio))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
# X:[B, N, C] x[:, :1]:提取出每个样本的第一个 Token
# cls_embed:[B,1,C]
# 获取类别嵌入,提取输入张量的第一个 token
cls_embed = x[:, :1]
# 通过注意力层处理分类嵌入,并与原始分类嵌入相加
cls_embed = cls_embed + self.norm1(self.attn(x[:, :1], H, W))
cls_embed = cls_embed + self.mlp(self.norm2(cls_embed), H, W)
return torch.cat([cls_embed, x[:, 1:]], dim=1)
cls_embed
的作用:在许多基于 Transformer 的模型(例如 BERT、ViT 等)中,第一个 Token 通常被设计为分类 Token(Classification Token),简称 cls
。这个 cls
Token 的输出会被用作整个序列的表示,用于下游的分类任务。通过 cls_embed = x[:, :1]
,我们提取出这个分类 Token,以便后续进行处理(如通过注意力层和前馈网络进行特征增强)。
在 ClassBlock
的 forward
方法中,cls_embed
被进一步处理,通过归一化层 self.norm1
和注意力层 self.attn
对 cls_embed
进行处理,并将处理后的结果与原始的 cls_embed
相加,实现特征的增强和融合。最终,将处理后的 cls_embed
与输入张量中除第一个 Token 以外的其他 Token 拼接在一起,形成最终的输出。
Block_mamba
常规Mamba块
class Block_mamba(nn.Module):
def __init__(self,
dim,
mlp_ratio,
drop_path=0.,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm2 = norm_layer(dim)
self.attn = GroupMambaLayer(dim, dim)
self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
# 将输入张量 x 通过注意力层处理,并应用 DropPath 后与原始输入相加
x = x + self.drop_path(self.attn(x, H, W))
# 对加和后的结果进行归一化,然后通过前馈网络处理,并应用 DropPath 后与原始输入相加
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
DownSamples
使用卷积完成特征图尺寸减半,并调整通道顺序,使其符合Mamba [B,N,C],通道在最后的输入要求。
class DownSamples(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.norm = nn.LayerNorm(out_channels)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
# 通过卷积层进行下采样,输出特征图大小减半
x = self.proj(x)
# 获取下采样后特征图的高度和宽度
_, _, H, W = x.shape
# 将特征图展平,从第三维开始展平,并转置,使得形状为 [B, H*W, C]
x = x.flatten(2).transpose(1, 2)
# 对展平后的特征进行层归一化
x = self.norm(x)
# 返回归一化后的特征以及高度和宽度
return x, H, W
Stem
完成初步特征提取,对输入进行4倍下采样,并调整shape为[B,N,C],符合Mamba输入
class Stem(nn.Module):
def __init__(self, in_channels, stem_hidden_dim, out_channels):
super().__init__()
hidden_dim = stem_hidden_dim
self.conv = nn.Sequential(
# 第一个卷积层,使用 7x7 卷积核,实现2倍下采样
nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2,
padding=3, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
# 第二个卷积层,使用 3x3 卷积核,不改变尺寸
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
# 第三个卷积层,使用 3x3 卷积核,不改变尺寸
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
)
# 2倍下采样
self.proj = nn.Conv2d(hidden_dim,
out_channels,
kernel_size=3,
stride=2,
padding=1)
self.norm = nn.LayerNorm(out_channels)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
# 通过卷积层序列进行特征提取和初步处理,尺寸减半
x = self.conv(x)
# 通过卷积层进行下采样,输出特征图大小减半,输出通道数转换为 out_channels
x = self.proj(x)
_, _, H, W = x.shape
# 将特征图展平,从第三维开始展平,并转置,使得形状为 [B, H*W, C]
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
GroupMamba
这个是核心,组合了前面所有
class GroupMamba(nn.Module):
def __init__(self,
in_chans=3,
num_classes=1000,
stem_hidden_dim = 32,
embed_dims=[64, 128, 348, 448],
mlp_ratios=[8, 8, 4, 4],
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3], # 每个阶段的深度(Block 的数量)
num_stages=4, # 阶段数量
distillation=True,
**kwargs
):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.num_stages = num_stages
# 生成 DropPath 的概率列表,线性递增
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0 # 初始化当前深度索引
for i in range(num_stages): # 遍历每个阶段
# 如果是第一个阶段,使用 Stem 模块进行特征提取
if i == 0: # [3,224,224] -> [64*64,64]
patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
else: # 否则,使用 DownSamples 模块进行下采样
patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])
# 为当前阶段创建 Block_mamba 模块列表
block = nn.ModuleList([Block_mamba(
dim = embed_dims[i],
mlp_ratio = mlp_ratios[i],
drop_path=dpr[cur + j],
norm_layer=norm_layer)
for j in range(depths[i])])
# 为当前阶段创建归一化层
norm = norm_layer(embed_dims[i])
# 更新当前深度索引
cur += depths[i]
# 动态添加 patch_embed 模块到类中
setattr(self, f"patch_embed{i + 1}", patch_embed)
# 动态添加 block 模块到类中
setattr(self, f"block{i + 1}", block)
# 动态添加归一化层到类中
setattr(self, f"norm{i + 1}", norm)
# 定义后处理层列表,这里仅包含 'ca' 一个后处理层
post_layers = ['ca']
self.post_network = nn.ModuleList([
# 创建后处理网络模块列表,使用 ClassBlock
ClassBlock(
dim = embed_dims[-1],
mlp_ratio = mlp_ratios[-1],
norm_layer=norm_layer)
for _ in range(len(post_layers))
])
# classification head
# 定义分类头,如果类别数量大于0,则使用全连接层,否则使用恒等映射
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
# distillation head
# 保存是否使用蒸馏
self.dist = distillation
if self.dist: # 如果使用蒸馏,定义蒸馏头cc
self.dist_head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
# 如果使用蒸馏,定义蒸馏头
self.apply(self._init_weights)
# 定义权重初始化的方法
def _init_weights(self, m):
# 如果模块是全连接层
if isinstance(m, nn.Linear):
# 使用截断正态分布初始化权重,标准差为 0.02
trunc_normal_(m.weight, std=.02)
# 如果全连接层有偏置,则将偏置初始化为 0
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
# 如果模块是层归一化
elif isinstance(m, nn.LayerNorm):
# 将偏置初始化为 0
nn.init.constant_(m.bias, 0)
# 将权重初始化为 1.0
nn.init.constant_(m.weight, 1.0)
# 如果模块是二维卷积层
elif isinstance(m, nn.Conv2d):
# 计算 fan_out,等于卷积核尺寸乘以输出通道数
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# 除以组数,得到平均 fan_out
fan_out //= m.groups
# 使用正态分布初始化权重,标准差为 sqrt(2.0 / fan_out)
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
# 如果卷积层有偏置,则将偏置初始化为 0
if m.bias is not None:
m.bias.data.zero_()
# 定义分类前向传播方法
def forward_cls(self, x, H, W):
# 计算类别 token 的嵌入,通过对序列维度取均值
cls_tokens = x.mean(dim=1, keepdim=True)
# 将类别嵌入与原始输入拼接在一起
x = torch.cat((cls_tokens, x), dim=1)
for block in self.post_network:
# 遍历后处理网络中的每个模块,进行处理
x = block(x, H, W)
return x
# 定义特征提取的前向传播方法
def forward_features(self, x):
B = x.shape[0] # 获取批量大小
# 遍历每个阶段,进行特征提取和处理
for i in range(self.num_stages):
# 获取当前阶段的 patch_embed 模块
patch_embed = getattr(self, f"patch_embed{i + 1}")
# 获取当前阶段的 block 模块列表
block = getattr(self, f"block{i + 1}")
# 通过 patch_embed 模块进行嵌入和下采样
x, H, W = patch_embed(x)
# 遍历当前阶段的每个 Block_mamba 模块,进行处理
for blk in block:
x = blk(x, H, W)
# 如果不是最后一个阶段,进行归一化和维度变换
if i != self.num_stages - 1:
# 获取当前阶段的归一化层
norm = getattr(self, f"norm{i + 1}")
# 对特征进行归一化
x = norm(x)
# 重塑特征图形状,并进行转置,使其适应后续的卷积层
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# 通过分类前向传播方法处理特征,并提取第一个 token
# cls_tokens :[B, 1, C] 与原始特征x在dim=1拼接,得到[B, N+1, C]
# self.forward_cls : [B, N+1, C]
# [:, 0]:提取第一个 token 的特征向量,形状为 [B, C]
x = self.forward_cls(x, 1, 1)[:, 0]
# 获取最后一个阶段的归一化层
norm = getattr(self, f"norm{self.num_stages}")
# 对特征进行归一化
x = norm(x)
# 返回归一化后的特征
return x
# 定义整体前向传播方法
def forward(self, x):
# 提取特征
x = self.forward_features(x)
if self.dist:# 如果使用蒸馏
# 通过分类头和蒸馏头获取输出
cls_out = self.head(x), self.dist_head(x)
# 如果不是训练模式,合并两个输出
if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2
else: # 如果不使用蒸馏,只通过分类头获取输出
cls_out = self.head(x)
return cls_out
DWConv
深度可分离卷积层
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
# 获取输入张量的批量大小 B、序列长度 N 和通道数 C
B, N, C = x.shape
# 转置张量,将通道维度移到第二维,并重塑为 [B, C, H, W]
x = x.transpose(1, 2).view(B, C, H, W)
# 通过深度可分离卷积层进行空间特征提取
x = self.dwconv(x)
# 将卷积后的张量展平,从第三维开始展平,并转置回 [B, N, C] 的形状
x = x.flatten(2).transpose(1, 2)
return x
后面是三个不同体量的模型版本
@register_model
def groupmamba_tiny(pretrained=False, **kwargs):
model = GroupMamba(
stem_hidden_dim = 32,
embed_dims = [64, 128, 348, 448],
mlp_ratios = [8, 8, 4, 4],
norm_layer = partial(nn.LayerNorm, eps=1e-6),
depths = [3, 4, 9, 3],
**kwargs)
model.default_cfg = _cfg()
return model
@register_model
def groupmamba_small(pretrained=False, **kwargs):
model = GroupMamba(
stem_hidden_dim = 64,
embed_dims = [64, 128, 348, 512],
mlp_ratios = [8, 8, 4, 4],
norm_layer = partial(nn.LayerNorm, eps=1e-6),
depths = [3, 4, 16, 3],
**kwargs)
model.default_cfg = _cfg()
return model
@register_model
def groupmamba_base(pretrained=False, **kwargs):
model = GroupMamba(
stem_hidden_dim = 64,
embed_dims = [96, 192, 424, 512],
mlp_ratios = [8, 8, 4, 4],
norm_layer = partial(nn.LayerNorm, eps=1e-6),
depths = [3, 6, 21, 3],
**kwargs)
model.default_cfg = _cfg()
return model
2.2 csms6s.py
定义了VMamba: Visual State Space Model
CrossScan
实现很巧妙,通过翻转实现不同方向的扫描
class CrossScan(torch.autograd.Function):
# 自定义的自动微分函数,继承自 torch.autograd.Function,
# 用于实现特定的前向和反向传播操作
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
# 将输入张量的形状信息保存到上下文 ctx 中,以便在反向传播时使用
ctx.shape = (B, C, H, W)
# 创建一个新的空张量 xs,形状为 [B, 4, C, H * W]
xs = x.new_empty((B, 4, C, H * W))
# 将输入张量 x 在高度和宽度维度上展平,得到形状为 [B, C, H * W]
xs[:, 0] = x.flatten(2, 3)
# 将输入张量 x 在高度和宽度维度上进行转置,然后展平,得到形状为 [B, C, H * W]
xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
# 对前两个切片进行翻转,得到形状为 [B, 2, C, H * W]
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
"""
反向传播方法,用于计算输入张量 x 的梯度。
参数:
ctx: 上下文对象,包含前向传播中保存的信息。
ys: 上游的梯度,形状为 [B, 4, C, H * W]
返回:
y: 输入张量 x 的梯度,形状为 [B, C, H, W]
"""
# 从上下文中恢复输入张量 x 的形状信息
B, C, H, W = ctx.shape
# 计算 H * W 的值
L = H * W
# 将上游梯度 ys 切分为前两个和后两个部分,并进行翻转和求和
# ys[:, 0:2] 的形状为 [B, 2, C, H * W]
# ys[:, 2:4] 先沿最后一个维度翻转,然后与 ys[:, 0:2] 相加
ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
# 将处理后的梯度 ys 切分为两部分
# ys[:, 0] 的形状为 [B, C, H * W]
# ys[:, 1] 先重塑为 [B, C, W, H],然后转置为 [B, C, H, W],最后展平为 [B, C, H * W]
y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
# 将梯度 y 从 [B, C, H * W] 重塑为 [B, C, H, W]
return y.view(B, -1, H, W)
核心就是:
xs[:, 0] = x
xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
xs = xs.view(B, 4, C, H, W)
举例说明:
假设输入张量 x
如下(简化为 1 个样本,1 个通道,2x2 图像)
x = [
[
[[a, b],
[c, d]]
]
] # Shape: [1, 1, 2, 2]
经过前向传播后的 xs
各个切片如下:
原始展平特征 (xs[:, 0]
):
[
[
[a, b, c, d]
]
] # Shape: [1, 1, 4]
转置后展平特征 (xs[:, 1]):
[
[
[a, c, b, d]
]
] # Shape: [1, 1, 4]
原始展平特征的翻转 (xs[:, 2]):
[
[
[d, c, b, a]
]
] # Shape: [1, 1, 4]
转置后展平特征的翻转 (xs[:, 3]):
[
[
[d, b, c, a]
]
] # Shape: [1, 1, 4]
即实现的就是如下图的四种不同方向的扫描:

CrossMerge
用于特征融合
前向传播:
- 将输入张量 ys进行四种不同的特征变换:
- 原始展平: 直接展平高度和宽度维度。
- 转置后展平: 先转置高度和宽度,再展平。
- 翻转原始展平: 对原始展平结果进行翻转。
- 翻转转置展平: 对转置后展平结果进行翻转。
- 通过相加操作融合这些变换后的特征,得到最终的输出
y
。
反向传播:
- 将上游梯度
x
进行逆向操作,恢复输入张量ys
的梯度,确保梯度的正确传播。
class CrossMerge(torch.autograd.Function):
@staticmethod
def forward(ctx, ys: torch.Tensor):
# 获取输入张量 ys 的形状,分别为批量大小 B、组数 K、维度 D、高度 H 和宽度 W
B, K, D, H, W = ys.shape
# 将高度和宽度信息保存到上下文 ctx 中,以便在反向传播时使用
ctx.shape = (H, W)
# 重塑 ys 张量的形状,将高度和宽度展平为一个维度,得到形状 [B, K, D, H * W]
ys = ys.view(B, K, D, -1)
# 对 ys 张量进行切片和翻转操作
# ys[:, 0:2] 选择前两组,形状为 [B, 2, D, H * W]
# ys[:, 2:4] 选择后两组,形状为 [B, 2, D, H * W]
# 对后两组进行翻转操作,沿最后一个维度(H * W)翻转
# 将翻转后的后两组与前两组相加,得到新的 ys,形状仍为 [B, 2, D, H * W]
ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# 进一步处理 ys 张量
# ys[:, 0] 形状为 [B, D, H * W]
# ys[:, 1] 形状为 [B, D, H * W]
# 将 ys[:, 1] 重塑为 [B, D, W, H],然后转置为 [B, D, H, W]
# 使用 contiguous() 保证内存连续性,再展平为 [B, D, H * W]
# 将 ys[:, 0] 和处理后的 ys[:, 1] 相加,得到 y,形状为 [B, D, H * W]
y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
# 返回处理后的张量 y,形状为 [B, D, H * W]
return y
@staticmethod
def backward(ctx, x: torch.Tensor):
"""
反向传播方法,用于计算输入张量 ys 的梯度。
参数:
ctx: 上下文对象,包含前向传播中保存的信息。
x: 上游的梯度,形状为 [B, D, H * W]
返回:
xs: 输入张量 ys 的梯度,形状为 [B, 4, D, H, W]
"""
# 从上下文中恢复高度和宽度信息
H, W = ctx.shape
# 获取输入张量 x 的形状,批量大小 B,维度 D,展平长度 L = H * W
B, C, L = x.shape
# 创建一个新的空张量 xs,形状为 [B, 4, C, L]
xs = x.new_empty((B, 4, C, L))
# 将输入梯度 x 赋值给 xs 的第一个切片
xs[:, 0] = x # 对应前向传播中的 ys[:, 0]
# 将输入梯度 x 重塑为 [B, C, H, W],然后转置为 [B, C, W, H],再展平为 [B, C, L]
xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) # 对应前向传播中的 ys[:, 1]
# 对前两个切片进行翻转操作,得到 [B, 2, C, L]
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # 对应前向传播中的 ys[:, 2:4]
# 将 xs 重塑回 [B, 4, C, H, W]
xs = xs.view(B, 4, C, H, W)
# 返回输入梯度 ys,形状为 [B, 4, C, H, W]
return xs
举例说明,假设输入张量 ys
的形状为 [1, 4, 2, 2, 2]
,即 B=1
, K=4
, D=2
, H=2
, W=2
:
ys = torch.tensor([[
[
[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]
],
[
[[9, 10],
[11, 12]],
[[13, 14],
[15, 16]]
],
[
[[17, 18],
[19, 20]],
[[21, 22],
[23, 24]]
],
[
[[25, 26],
[27, 28]],
[[29, 30],
[31, 32]]
]
]], dtype=torch.float32) # Shape: [1, 4, 2, 2, 2]
切片和翻转操作:
-
选择前两组
ys[:, 0:2] = [ [ [1, 2, 3, 4], [5, 6, 7, 8] ], [ [9, 10, 11, 12], [13, 14, 15, 16] ] ] # Shape: [1, 2, 2, 4]
-
选择后两组并翻转
flipped = torch.flip(ys[:, 2:4], dims=[-1]) = [
[
[20, 19, 18, 17],
[24, 23, 22, 21]
],
[
[28, 27, 26, 25],
[32, 31, 30, 29]
]
] # Shape: [1, 2, 2, 4]
- 相加操作:
ys = ys[:, 0:2] + flipped = [
[
[1+20, 2+19, 3+18, 4+17],
[5+24, 6+23, 7+22, 8+21]
],
[
[9+28, 10+27, 11+26, 12+25],
[13+32, 14+31, 15+30, 16+29]
]
] = [
[
[21, 21, 21, 21],
[29, 29, 29, 29]
],
[
[37, 37, 37, 37],
[45, 45, 45, 45]
]
] # Shape: [1, 2, 2, 4]
进一步处理:
-
处理第一组: ys[:, 0] = [21, 21, 21, 21]
-
处理第二组:
ys[:, 1].view(1, -1, 2, 2) = [ [ [21, 21], [21, 21] ], [ [37, 37], [37, 37] ], [ [29, 29], [29, 29] ], [ [45, 45], [45, 45] ] ] # Shape: [1, 2, 2, 2]
-
转置: 将高度和宽度维度交换,得到
[B, D, H, W]
形状的张量:
ys[:, 1].view(1, -1, 2, 2).transpose(2, 3) = [
[
[21, 21],
[21, 21]
],
[
[37, 37],
[37, 37]
],
[
[29, 29],
[29, 29]
],
[
[45, 45],
[45, 45]
]
] # Shape: [1, 2, 2, 2]
- 展平: 展平为 [B, D, H * W],即 [1, 2, 4]
y = [21+21, 21+21, 21+21, 21+21] + [37+37, 37+37, 37+37, 37+37]
= [42, 42, 42, 42] + [74, 74, 74, 74]
= [116, 116, 116, 116]
# Shape: [1, 2, 4]
最终前向输出:
y = torch.tensor([[
[116, 116, 116, 116],
[148, 148, 148, 148]
]], dtype=torch.float32) # Shape: [1, 2, 4]
CrossScan_1和CrossMerge_1
这两个是一组,可以看成是CrossScan的第一分支,包含的就是原始数据,没有进行转置和翻转
class CrossScan_1(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
# 获取输入张量 x 的形状,分别为批量大小 B、通道数 C、高度 H 和宽度 W
B, C, H, W = x.shape
# 将输入张量的形状信息保存到上下文 ctx 中,以便在反向传播时使用
ctx.shape = (B, C, H, W)
# 创建一个新的空张量 xs,形状为 [B, 1, C, H * W]
xs = x.new_empty((B, 1, C, H * W))
# 将输入张量 x 在高度和宽度维度上展平,赋值给 xs 的第一个切片
xs[:, 0] = x.flatten(2, 3)
# 返回处理后的张量 xs,形状为 [B, 1, C, H * W]
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
"""
反向传播方法,用于计算输入张量 x 的梯度。
参数:
ctx: 上下文对象,包含前向传播中保存的信息。
ys: 上游的梯度,形状为 [B, 1, C, H * W]
返回:
y: 输入张量 x 的梯度,形状为 [B, C, H, W]
"""
# 从上下文中恢复输入张量 x 的形状信息
B, C, H, W = ctx.shape
# 计算展平后的长度 L = H * W
L = H * W
# 从上游梯度 ys 中提取第一个切片,形状为 [B, C, H * W]
y = ys[:, 0]
# 将梯度 y 从 [B, C, H * W] 重塑为 [B, C, H, W]
return y.view(B, -1, H, W)
class CrossMerge_1(torch.autograd.Function):
@staticmethod
def forward(ctx, ys: torch.Tensor):
B, K, D, H, W = ys.shape
ctx.shape = (H, W)
ys = ys.view(B, K, D, -1)
y = ys[:, 0]
return y
@staticmethod
def backward(ctx, x: torch.Tensor):
# B, D, L = x.shape
# out: (b, k, d, l)
H, W = ctx.shape
B, C, L = x.shape
xs = x.new_empty((B, 1, C, L))
xs[:, 0] = x
xs = xs.view(B, 1, C, H, W)
return xs
CrossScan_2和CrossMerge_2
交换宽高后展平
class CrossScan_2(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
ctx.shape = (B, C, H, W)
xs = x.new_empty((B, 1, C, H * W))
xs[:, 0] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
# out: (b, k, d, l)
B, C, H, W = ctx.shape
L = H * W
y = ys[:, 0].view(B, -1, H, W).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
return y.view(B, -1, H, W)
CrossScan_3
展平后翻转
class CrossScan_3(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
ctx.shape = (B, C, H, W)
xs = x.new_empty((B, 1, C, H * W))
xs[:, 0] = torch.flip(x.flatten(2, 3), dims=[-1])
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
# out: (b, k, d, l)
B, C, H, W = ctx.shape
L = H * W
y = ys[:, 0].flip(dims=[-1]).view(B, 1, -1, L)
return y.view(B, -1, H, W)
CrossScan_4
转置展平后翻转
class CrossScan_4(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
ctx.shape = (B, C, H, W)
xs = x.new_empty((B, 1, C, H * W))
xs[:, 0] = torch.flip(x.transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1])
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
# out: (b, k, d, l)
B, C, H, W = ctx.shape
L = H * W
y = ys[:, 0].view(B, -1, H, W).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L).flip(dims=[-1]).view(B, 1, -1, L)
return y.view(B, -1, H, W)
3.测试
原先的VMamba环境和Vim环境均报错,这里直接新建环境,省得麻烦,按照Groupmamba github作者的步骤:
conda create -n groupmamba python=3.10.13
conda activate groupmamba
安装Pytorch-GPU,实测pytorch-cuda=12.1会报错,所以安装11的
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
cd到源码包中安装其他依赖
cd groupmamba
pip install -r requirements.txt
cd到源码包中安装Mamba,貌似Vmamba和原始Mamba都不行,原先VMamba中提供的selective_scan是0.0.1版本,工作正常,现在这个版本是0.0.2。Groupmamba github给的方式是直接在源码目录中执行如下命令安装,可以安装正常,如果报错可以参考这个解决:https://github.com/MzeroMiko/VMamba/issues/216)
cd kernels/selective_scan && pip install .
在groupmamba下写测试代码:
model = groupmamba_base()
# print(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
output = model(dummy_input)
print(output[0].shape) # 分类头输出
print(output[1].shape) # 蒸馏头输出
输出:
torch.Size([1, 1000])
torch.Size([1, 1000])