初识Mamba

本文的主要内容: 

1,对一篇关于Mamba模型解读的阅读笔记

一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba(被誉为Mamba最佳解读)_mamba模型-优快云博客

1,对一篇关于Mamba科普文的阅读笔记 (涉及较多前置知识特别是ssm/S4的介绍)

A Visual Guide to Mamba and State Space Models - Maarten Grootendorst

2,对提出Mamba原始论文的阅读笔记

https://minjiazhang.github.io/courses/fall24-resource/Mamba.pdf

4,对mamba论文的一作Albert Gu在YouTube上关于S4论文的解读

https://www.youtube.com/watch?v=luCBXCErkCs

5,重温SSM(一):线性系统和HiPPO矩阵 - 科学空间|Scientific Spaces

6,github上mamba代码: 

https://github.com/state-spaces/mamba​​​​​​


Mamba,其基于SSM或S4发展为S6(S4 models with a selection mechanism and computed with a scan

SSM、HiPPO、S4->Mamba 

RNN的缺陷:

1,由于梯度消失和梯度爆炸的原因,ht一般只包含前面若干步而非之前所有步的隐藏状态。(短期记忆)

2,RNN没法并行训练(推理快但训练慢),因为其每个时间步的输出依赖于前一个时间步的输出

3,因为RNN的结构,导致其没法写成卷积形式(因为RNN多了一个非线性的转换函数,比如tanh)   

状态空间与SSM

状态向量(state vectors):描述状态的变量。 

状态空间模型(SSM): 用于描述状态表示并根据某些输入预测其下一个状态可能是什么的模型

一般SSMs包括以下组成

  • 输入序列x(t)(不使用离散序列(如向左移动一次),而是将连续序列作为输入并预测输出序列)
  • 潜在状态表示h(t),比如距离出口距离和 x/y 坐标
  • 预测输出序列y(t),比如再次向左移动以更快到达出口

SSM 假设动态系统可以通过两个方程从其在时间t 时的状态进行预测

 第一个方程与RNN中的非常类似,其中的A存储着之前所有历史信息的浓缩精华(可以通过一系列系数组成的矩阵表示),以基于A更新下一个时刻的空间状态hidden state

注意!A、B、C、D这4个矩阵是参数,是可以学习到的,但学习好之后,在SSM中,矩阵A、B、C、D便固定不变了——即便是在不同的输入之下,但到了后续的改进版本mamba中则这4个矩阵可以随着输入不同而可变 

SSM->离散化SSM、循环表示、卷积表示、HiPPO处理长序列->S4->对角化->S4D

1,SSM的离散化(离散的输入->连续化->连续的输入->SSM->连续的输出->离散化->离散的输出)

SSM处理连续输入的话就叫连续SSM,如果处理离散输入的话,就叫离散SSM。所以SSM的离散化就是让SSM能够处理离散的输入。我们让离散的输入连续化(使用零阶保持器),这样就获得了SSM可以使用的连续信号(从而得到离散SSM),有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长( 采用零阶保持器时,保持值的时间由一个新的可学习参数表示,称为步长(siz)——\Delta ),对连续输出值进行采样,从而得到离散的输出。

可以针对A、B按如下方式做零阶保持(做了零阶保持的在对应变量上面加了个横杠

 注意:我们在保存时,仍然保存矩阵A的连续形式(而非离散化版本),只是在训练过程中,连续表示被离散化(During training, the continuous representation is discretized)

we note that structured SSMs are so named because computing them efficiently also requires imposing structure on the 𝑨 matrix. The most popular form of structure is diagonal.

 2,循环结构表示:方便快速推理

 

3,卷积结构表示(将 SSM 表示为卷积使它可以像CNN一样进行并行训练)

从而把SSM表示为了卷积形式。由于其中三个离散参数A、B、C都是常数,因此我们可以预先计算左侧向量并将其保存为卷积核,这为我们提供了一种使用卷积超高速计算y的简单方法。然而,由于内核大小固定,它们的推理不如 RNN 那样快速。

所以在训练模式下-CNN,在推理模式下->RNN

4,解决长距离依赖问题:HiPPO

矩阵A 产生了隐藏状态(matrix A captures information from previous state to build new state),且矩阵A只记住之前的几个token和捕获迄今为止看到的每个token之间的区别。->长程依赖问题

如何解决这个问题呢?可以使用HiPPO(HiPPO 矩阵通常用于将连续时间信号投影到正交多项式基下,以代表过去的状态/信息),解决如何在有限的存储空间中有效地解决序列建模的长距离依赖问题。

HiPPO尝试将当前看到的所有输入信号压缩为一个系数向量,它使用矩阵A构建一个“可以很好地捕获最近的token并衰减旧的token”状态表示。

Building matrix A using HiPPO was shown to be much better than initializing it as a random matrix. As a result, it more accurately reconstructs newer signals (recent tokens) compared to older signals (initial tokens).

 正由于HiPPO 矩阵可以产生一个隐藏状态来记住其历史(从数学上讲,它是通过跟踪Legendre polynomial 的系数来实现的,这使得它能够逼近所有以前的历史),使得在被应用于循环表示和卷积表示中时,可以处理远程依赖性。

S4:一种可以处理长序列的SSM(序列的结构化状态空间Structured State Space for Sequences)

S4D: S4的对角版本

为了提高实际可行性,S4D将参数矩阵标准化为对角结构

如上图左侧所示,当基于HiPPO的A矩阵变换为对角线结构之后,便使得其可以被视为一组一维SSM
如上图右侧所示,作为卷积模型,S4D具有简单且可解释的卷积核,可以用两行代码实现
颜色表示独立的一维SSM;紫色表示可训练参数「 Colors denote independent1-D SSMs;purple denotes trainable parameters」

Mamba的创新:有选择处理信息 + 硬件感知算法 

1,对输入信息有选择性处理(Selection Mechanism)

(1)transformer(more powerful,less efficient)的注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(storing the entire context,也就是KV缓存),直接导致训练和推理消耗算力大

在训练的时候快,因为可以并行化。在推理的时候慢,因为Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢。
(2)RNN(more effecient,less powerful)的推理和训练效率高,但性能容易受到对上下文压缩程度的限制
On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context.

好比,RNN每次只参考前面固定的字数(When generating the output, the RNN only needs to consider the previous hidden state and current input. It prevents recalculating all previous hidden states which is what a Transformer would do),写的快是快,但容易忘掉更前面的内容

(3)CNN训练效率高,可并行「因为能够绕过状态计算,并实现仅包含(B, L, D)的卷积核

(4)SSM:Linear Time Invariance规定:推理时 SSM中的A、B、C不因输入不同而不同(Constant regardless of the input),使得SSM无法针对输入做针对性的推理。(有长程依赖问题)

S4(more effecient,less powerful) 引入HiPPO解决了长程依赖问题,压缩每一个历史记录,无法针对输入做针对性的推理,认为每个token的重要程度都相同。(有四个参数∆, ABC

(5)Mamba(more powerful and more effecient)

a simple selection mechanism by parameterizing the SSM parameters based on the input

虽然在推理时参数本身也不变,但由于其设计中引入的选择性机制,使得模型能够根据不同输入token的特点进行有区别的对待。即 Mamba可根据不同的输入数据动态计算矩阵B、C和步长Δ的值(推理过程中不会对模型的参数进行重新训练或调整,而是简单地应用训练阶段学到的方式——决定如何计算这些矩阵和步长的函数或映射,来生成预测) 。

与其相关的代码如下:

zxbcdt = self.in_proj(u)
z0, x0, z, xBC, dt = torch.split(
                zxbcdt,
                [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
                dim=-1
            )
 if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]
                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
            else:
                xBC = causal_conv1d_fn(
                    xBC.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=seq_idx,
                ).transpose(1, 2)
 x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)

可以大概这样理解:其中u是输入,zxbcdt与输入有关,把他分成五部分,其实的dt即是detla。把其中的xBC经过一个卷积之后分成三部分,其中两部分分别为B和C。

代码链接:

mamba/mamba_ssm/modules/mamba2.py at v2.2.4 · state-spaces/mamba

这样可以实现更关注于输入中对当前任务更重要的部分。但是这样的话我们无法预计算\overline{\mathbf{K}},也无法使用CNN模式来训练我们的模型。所以需要找到一种无需卷积的并行训练方式。

efficient models must have a small state, while effective models must have a state that contains all necessary information from the context.

  • 高效的模型必须有一个小的状态(比如RNN或S4)
  • 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)

2,硬件感知的算法(Hardware-aware Algorithm)

这是为了解决速度问题。主要通过三种方法

1,kernel fusion 来减少amount of memory IOs

2,parallel scan 来实现并行化的计算

3,recomputation 在反向传播过程中重新计算h来减少loaded from HBM to SRAM的时间

1,

 借鉴flash attention ,限制需要从 DRAM 到 SRAM 的次数(通过内核融合kernel fusion来实现),避免一有个结果便从SRAM写入到DRAM,而是待SRAM中有一批结果再集中写入DRAM中,从而降低来回读写的次数

 最终在更高速的SRAM内存中执行离散化和递归操作,再将输出写回HBM,具体来说

不是在GPU HBM(高带宽内存)中将大小为(B,L,D,N)的扫描输入进(A,B),而是直接将SSM参数从慢速HBM加载到快速SRAM中

然后,在SRAM中进行离散化,得到(B,L,D,N)的
接着,在SRAM中进行scan(通过并行扫描算法实现并行化),得到(B,L,D,N)的输出ht
最后,multiply and sum with C,得到(B,L,D)的最终输出写回HBM

instead of preparing the scan input (A, B) of size (𝙱, 𝙻, 𝙳, 𝙽) in GPU HBM (high-bandwidth memory), we load the SSM parameters (Δ, 𝑨, 𝑩, 𝑪) directly from slow HBM to fast SRAM, perform the discretization and recurrence in SRAM, and then write the final outputs of size (B, L, D) back to HBM

2,该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算(使得不用CNN也能并行训练),但为了减少GPU内存层次结构中不同级别之间的IO访问,没有具体化扩展状态

扫描操作:每个状态比如H_1都是前一个状态比如H_0乘以\overline{\mathrm{A}},加上当前输入X_1乘以\overline{\mathrm{B}}的总和,这就叫扫描操作(scan operation),可以使用 for 循环轻松计算,然这种状态之下想并行化是不可能的(因为只有在获取到前一个状态的情况下才能计算当前的每个状态)

parallel scan algorithm(并行扫描算法)是一种允许在保持循环计算特性的同时,对序列数据进行并行处理的技术。这种方法可以在处理序列时,对序列的各个部分同时进行计算——而不是一个接一个地处理,从而实现并行化)

选择性扫描算法(selective scan algorithm):动态矩阵B和C以及并行扫描算法一起创建

Together, dynamic matrices B and C, and the parallel scan algorithm create the selective scan algorithm


 

3,注意,当输入从HBM加载到SRAM时,中间状态不被保存,而是在反向传播中重新计算
the intermediate states are not stored but recomputed in the backward pass when the inputs are loaded from HBM to SRAM

The intermediate states are not saved but are necessary for the backward pass to compute the gradients. Instead, the authors recompute those intermediate states during the backward pass. 

Although this might seem inefficient, it is much less costly than reading all those intermediate states from the relatively slow DRAM.

3,更简单的架构
将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计

H3:大多数SSM架构

Gated MLP:Transformer中普遍存在的Gated MLP

为何要做线性投影?
经过线性投影后,输入嵌入的维度可能会增加,以便让模型能够处理更高维度的特征空间,从而捕获更细致、更复杂的特征


为什么SSM前面有个卷积?

a convolution before the Selective SSM is applied to prevent independent token calculations.
本质是对数据做进一步的预处理,更细节的原因在于:
  SSM之前的CNN负责提取局部特征(因其擅长捕捉局部的短距离特征),而SSM则负责处理这些特征并捕捉序列数据中的长期依赖关系,两者算互为补充
  CNN有助于建立token之间的局部上下文关系,从而防止独立的token计算
毕竟如果每个 token 独立计算,那么模型就会丢失序列中 token 之间的上下文信息。通过先进行卷积操作,可以确保在进入 SSM 之前,序列中的每个 token 已经考虑了其邻居 token 的信息。这样,模型就不会单独地处理每个 token,而是在处理时考虑了整个局部上下文

The Selective SSM has the following properties:

  • Recurrent SSM created through discretization
  • HiPPO initialization on matrix A to capture long-range dependencies
  • Selective scan algorithm to selectively compress information
  • Hardware-aware algorithm to speed up computation

https://www.youtube.com/watch?v=luCBXCErkCs

 模型参数类ModelArgs、完整的Mamba模型类Mamba、残差块类ResidualBlock、单个Mamba块类MambaBlock、RMSNorm归一化类以及一些辅助函数

mamba
├── benchmarks //包含用于基准测试 Mamba 与其他模型性能的脚本。
│ 	└── benchmark_generation_mamba_simple.py  // 示例模型的推理脚本
├── csrc
│ 	└── selective_scan  // 选择性扫描的c++实现
├── evals //用于模型评估和下游任务的实验代码。
│ 	└── lm_harness_eval.py
├── mamba_ssm
│ 	├── models
│   │   ├── config_mamba.py
│   │   └── mixer_seq_simple.py  // 使用mamba构建的一个完整的语言模型示例
│ 	├── modules
│   │   └── mamba_simple.py   // mamba block的实现
│ 	├── ops
│   │   ├── triton
│   │   │   ├── layernorm.py
│   │   │   ├── selective_state_update.py
│   │   └── selective_scan_interface.py   // 选择性SSM层的实现
│ 	├── utils
│   │   ├── generation.py
│   │   └── hf.py
└── test  //单元测试代码,验证项目中的各个模块是否正常工作。
		└── ops
		    ├── triton
		    │		├── test_selective_state_update.py
        └──test_selective_scan.py


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值