论文阅读:VMamba Visual State Space Model

文章简介

  • 本文主要设计了一种基于Mamba的计算高效的网络架构,将状态空间语言模型 Mamba 移植到 VMamba 中
  • 主要由中国科学院、华为、鹏城实验室的研究员提出

问题引入

  • 视觉特征学习是计算机视觉的基础研究领域,目前主要使用两种骨干网络,卷积神经网络(CNN)和基于Transformer架构的的ViT;
  • ViT的表征能力虽然强,但是计算复杂度是二次的。想寻找一种方法降低ViT的计算复杂度;
  • 作者想借鉴从自然语言处理(NLP)领域引入的Mamba方法,但是视觉数据具有非因果性质,直接将Mamba中的选择性扫描机制应用于图像会导致感受野受限。

设计思路

  • 在VSS块的基础上开发一系列VMamba架构,这种网络具有线性的时间复杂度;
  • 引入二维选择性扫描(SS2D),以弥补一维阵列扫描和二维平面遍历之间的差距,促进选择性 SSM 在处理视觉数据方面的扩展;
  • 提出通过新引入的交叉扫描模块(Cross-Scan Module,CSM)来解决方向敏感问题。CSM 不是以单向模式(列向或行向)遍历图像特征映射的空间域,而是采用四向扫描策略,即从特征映射的四个角到相对位置。这种策略确保特征映射中的每个元素从不同方向的所有其他位置集成信息,从而产生全局感受野,而不增加线性计算复杂性。

方法

Vmamba网络架构

我们在三个规模上开发了 VMamba:VMamba-Tiny、VMamba-Small 和 VMamba-Base(分别称为 VMamba-T、VMamba-S 和 VMamba-B)。图 3 (a) 显示了 VMamba-T 架构的概览,输入图像 I ∈ R H × W × 3 I ∈ R^{H×W ×3} IRH×W×3 首先由主干模块分割成块,得到空间维度为 H / 4 × W / 4 H/4 × W/4 H/4×W/4 的 2D 特征图。随后,采用多个网络阶段创建分辨率为 H / 8 × W / 8 、 H / 16 × W / 16 H/8 × W/8、H/16 × W/16 H/8×W/8H/16×W/16 H / 32 × W / 32 H/32 × W/32 H/32×W/32 的分层表示。每个阶段都包含一个下采样层(第一阶段除外),后面是一堆视觉状态空间 (VSS) 块。
请添加图片描述
VSS 块是 Mamba 块(图 3(b))的视觉对应物,用于表征学习。VSS 块的初始架构(图 3(c)中称为“原始 VSS 块”)由 S6 模块(Mamba 的核心,可同时实现全局感受野、动态权重(即选择性)和线性复杂度)替换为新提出的 2D 选择性扫描 (SS2D) 模块(将在下一小节中介绍)。为了进一步提高计算效率,我们消除了整个乘法分支(图 3(c)中红框圈出的部分),因为门控机制的效果已通过 SS2D 的选择性实现。因此,最终的 VSS 块(如图 3(d)所示)由一个带有两个残差模块的网络分支组成,模仿原始 Transformer 块的架构 [60]。在本文中,所有结果都是使用此架构中 VSS 块构建的 VMamba 模型获得的。

请添加图片描述

2D选择性扫描机制(SS2D)

VMamba将选择性扫描机制(S6)作为核心 SSM 运算符,但它以因果方式处理输入数据,因此只能捕获数据的扫描部分内的信息。这使 S6 在应用到非因果数据(例如图像、图形、集合等)时面临重大挑战。解决这个问题的一种直接方法是沿两个不同方向(即前向和后向)扫描数据,允许它们互相补偿而不增加计算复杂性。

尽管图像具有非因果性质,但还有一个与文本的不同之处在于它们包含 2D 空间信息(例如局部纹理和全局结构)。
虽然 S6 中扫描操作的顺序性与涉及时间数据的 NLP 任务非常吻合,但它在应用于视觉数据时带来了重大挑战,因为视觉数据本质上是非顺序的,并且包含空间信息(例如局部纹理和全局结构)。为了解决这个问题,作者使用选择性扫描方法进行输入处理,并提出 2D 选择性扫描 (SS2D) 模块,使 S6 适应视觉数据,而不会损害其优势。

如图 2 所示,SS2D 中的数据转发涉及三个步骤:交叉扫描、使用 S6 块进行选择性扫描和交叉合并。对于输入数据,SS2D 首先将输入的 patch 沿着四条不同的遍历路径展开为序列(即 Cross-Scan),然后使用单独的 S6 block 并行处理每个 patch 序列,随后对得到的序列进行 reshape 和合并以形成输出图(即 Cross-Merge)。通过采用互补的一维遍历路径,SS2D 使图像中的每个像素能够有效地整合来自不同方向的所有其他像素的信息,从而促进在二维空间中建立全局感受野。
请添加图片描述

加速 VMamba

如图 3 (e) 所示,带有 vanilla VSS 块的 VMamba-T 模型(称为“Vanilla VMamba”)实现了 426 张图像/秒的吞吐量,包含 22.9M 个参数和 5.6G FLOP。尽管分类准确率达到了最先进的 82.2%(在微小级别上,比 Swin-T高出 0.9%),但低吞吐量和高内存开销对 VMamba 的实际部署构成了重大挑战。在本小节中,我们概述了为提高其推理速度所做的努力,主要侧重于实现细节和架构设计的改进。我们通过 ImageNet-1K 上的图像分类来评估模型。每个渐进式改进的影响总结如下,其中 (%, img/s) 分别表示 ImageNet-1K 上的 top-1 准确率和推理吞吐量的增益。请添加图片描述
步骤 (a) (+0.0%,+41 图片/秒) 通过在 Triton 中重新实现 Cross-Scan 和 Cross-Merge。

步骤 (b) (+0.0%,-3 图片/秒) 通过调整选择性扫描的 CUDA 实现以适应 float16 输入和 float32 输出。尽管测试时速度略有波动,但这显著提高了训练效率(吞吐量从 165 到 184)。

步骤 © (+0.0%,+174 图片/秒) 通过用线性变换(即 torch.nn. functional.linear)替换选择性扫描中相对较慢的 einsum。我们还采用了 (B, C, H, W) 的张量布局来消除不必要的数据排列。

步骤 (d) (-0.6%,+175 图片/秒) 通过将 MLP 引入 VMamba 来实现,因为它的计算效率更高。我们还丢弃了 DWConv(深度卷积 [24])层,并将层配置从 [2,2,9,2] 更改为 [2,2,2,2],以降低 FLOP。

步骤 (e)(+0.6%,+366 图片/秒)通过丢弃整个乘法分支,采用图 3 (d) 中所示的 VSS 块,并将参数 ssm-ratio(特征扩展因子)从 2.0 降低到 1.0。这使我们能够将层数提高到 [2,2,5,2],同时降低 FLOP。

步骤 (f)(+0.3%,+161 图片/秒)通过将参数 d_state(SSM 状态维度)从 16.0 降低到 1.0。这使我们能够将 ssm-ratio 提高回 2.0 并引入 DWConv 层而不增加 FLOP。

步骤 (g) (+0.1%,+346 张图片/秒),将 ssm 比率降低到 1.0,同时将层配置从 [2,2,5,2] 更改为 [2,2,8,2]。

实验

在本节中,我们进行了一系列实验来评估 VMamba 的性能,并将其与各种视觉任务中的流行基准模型进行比较。我们还通过将所提出的 2D 特征图遍历方法与其他方法进行比较来验证其有效性。此外,我们通过可视化其有效感受野 (ERF) 和激活图来分析 VMamba 的特性,并检查其在越来越长的输入序列下的可扩展性。我们主要遵循 Swin 的超参数设置和实验配置。

所有实验均在具有 8 × NVIDIA Tesla-A100 GPU 的服务器

### VMamba云服务器使用说明 VMamba 是一种基于视觉状态空间模型(Visual State Space Model, VSSM)的技术框架,主要用于处理复杂的图像和视频数据建模任务。其核心组件包括 Mamba 和 VSS 块,这些模块通过高效的计算架构实现了高精度的分类性能以及较低的时间复杂度[^2]。 #### 环境准备 为了在云服务器上部署 VMamba,需确保满足以下依赖条件: 1. **操作系统**: 推荐 Ubuntu 22.04 LTS。 2. **Python 版本**: Python >= 3.8. 3. **GPU 支持**: NVIDIA GPU 配合 CUDA 工具链 (推荐版本为 CUDA 11.x)。 4. **软件包管理工具**: 使用 `conda` 或者 `mamba` 来加速环境配置过程。 5. **深度学习框架**: OpenAI 的 Triton 库支持高效张量操作[^1]。 以下是具体的安装步骤: ```bash # 更新系统包索引 sudo apt update && sudo apt upgrade -y # 安装必要的开发库 sudo apt install build-essential cmake git libboost-all-dev \ libeigen3-dev python3-pip python3-setuptools python3-wheel -y # 创建一个新的 conda 虚拟环境 mamba create -n vmamba_env python=3.9 -y mamba activate vmamba_env # 安装 PyTorch 及其他依赖项 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 # 克隆 VMamba GitHub 仓库 git clone https://github.com/author/vmamba.git cd vmamba # 安装项目特定需求 pip install -r requirements.txt # 测试安装是否成功 python test_vmamba.py ``` 如果遇到 causal_conv1d 或 mamba_ssm 组件无法正常安装的情况,可以尝试手动编译源码或者切换到预构建二进制文件的方式解决此问题。 #### 运行实例 假设已经完成上述准备工作,则可以通过如下命令启动训练脚本: ```bash # 替换 YOUR_CONFIG_PATH 为你实际使用的配置路径 python train.py --config YOUR_CONFIG_PATH ``` 对于推理阶段的操作流程也类似,只需加载已保存好的权重参数即可执行预测功能。 --- ### 注意事项 当运行过程中发现任何异常情况时,请仔细检查日志输出信息,并参照官方文档进一步排查原因。此外,在大规模分布式场景下应用该技术前还需充分评估硬件资源消耗状况以保障整体稳定性。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值