Mamba模型(Mmaba1, 没有mamba2)

问题

Transformer模型基于矩阵乘法,注意力机制中Q、K、V(N×d)会进行大量运算。

如:使用Q、K运算得到score的过程中,将Q每一行与K每一列进行运算,需要N^2次点积,每次点积包含d次乘法。即最终的复杂度为O(N^{2}d)。考虑到序列长度(L)、batch大小(B)、多层Transformer块的堆叠(M),计算量为M\cdot (24bLd^{2}+4bL^{2}d),参考

RNN存在梯度消失的问题,且训练时间长。

状态空间模型

Mamba中的状态空间模型(SSM)的灵感来源于传统的状态空间模型,公式描述如下:

其中,h ( t ) h(t)h(t)是当前的状态量,A AA是状态转移矩阵;x ( t ) x(t)x(t)为输入的控制量,B BB表示控制量对状态量的影响。y ( t ) y(t)y(t)表示系统的输出,C CC表示当前状态量对输出影响,D DD表示当前控制量对输出影响。

这个模型是时不变(time-invariant)系统,时不变系统是输出不会直接随着时间变化的系统。任意时间延迟的输入将得到相同时间延迟的输出(ABCD其实可以带上时间参数,那就是时变系统,延时输入会导致不同的输出)

离散化

 为了方便运算,将上述状态空间模型离散化,使用零阶保持法。 离散近似的结果为:

其中I为单位矩阵

 卷积化

 Mamba中的状态空间模型计算效率最大的提升就在于其序列运算可以卷积化。从0时刻开始向后推导状态空间模型几个时刻后的输出,即可得到如下的形式:

设计适当的卷积核即可将序列运算转化为卷积运算,形式如下所示:

 见解:设序列长度为x。使用高性能的卷积模式进行运算时会使用输入的所有信息预先计算K(0...x-1),最后直接与对应位置输入相乘即可。即这时Mamba是非因果的,每个点使用了未来的信息。

 选择性机制

传统状态空间模型的时序结构导致了其输出状态完全依赖有序的输入数据。对于增减过、顺序打乱过的输入,在一些不相关数据混杂在序列中出现时,状态空间模型就无法对其进行有效处理

Mamba针对这一情况进行了改进,在对B BB,C CC矩阵进行计算时,加入了选择性机制,即在计算是引入一个额外的线性层,对输入的输入的控制量和状态量进行选择,加强模型对不同输入形式的适应能力

 选择性机制对状态空间模型的改进

并行累加计算流程 

mamba结构

与Transformer结构类似,Mamba结构也是由若干Mamba块堆叠而成。一个基本的Mamba块结构如图所示:Mamba块由H3块以及门控MLP组合而成。H3为Hungry Hungry Hippos,是一种状态空间模型的执行方式。Mamba块简化了H3的结构,并与门控MLP结合,添加了残差项防止梯度消失。

使用mamba搭建模型一般会包含3个类似残差连接的部件:SSM内部的D,Mamba中除去SSM的另一条分支最后与SSM结果相乘,Mamba块外,即模型搭建时的残差连接

使用高效硬件乘法时先计算非离散化参数,送至GPU的SRAM(据我所知应该是share memory)中进行离散化及SSM运算,再回传到普通内存(HBM)。

此外,SSM正向传播并不会保存形状为(B, L, D, N)的中间状态以供反向传播训练,而是在反向传播时重新计算,与将其保存到HBM并读取相比,重新计算的耗时更短。

参考:

一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba_mamba模型-优快云博客

Mamba模型底层技术详解,与Transformer到底有何不同?-优快云博客

Mamba:2 状态空间模型 - 大模型知识库|大模型训练|开箱即用的企业大模型应用平台|智能体开发|53AI

 原文代码:mamba/README.md 在 main ·状态空格/曼巴 (github.com)

论文:[2312.00752] Mamba:使用选择性状态空间的线性时间序列建模 (arxiv.org)

参考代码:mamba-minimal/README.md at master · johnma2006/mamba-minimal (github.com) 

### Mamba 模型的技术原理 Mamba 是一种基于状态空间模型(State Space Model, SSM)的序列建模架构,其设计灵感来源于结构化状态空间模型的研究进展。与传统的递归神经网络(RNN)和Transformer不同,Mamba通过引入线性时不变(LTI)系统的概念,将序列建模问题转化为对状态空间的动态建模。在Mamba中,输入序列被映射到一个隐状态空间,并通过状态转移矩阵进行更新。这种设计允许模型以线性复杂度处理长序列,从而显著提升计算效率。 Mamba的核心技术包括: 1. **状态空间表示**:Mamba利用状态空间模型来捕捉序列数据的动态特性。其核心公式为: - 状态转移方程:$x_{t+1} = A x_t + B u_t$ - 输出方程:$y_t = C x_t + D u_t$ 其中,$x_t$是隐状态,$u_t$是输入,$y_t$是输出,而$A$、$B$、$C$、$D$是参数矩阵[^1]。 2. **选择机制**:Mamba引入了一个选择机制,通过可学习的参数调整状态转移矩阵$A$和输入矩阵$B$,使得模型能够动态地适应不同的输入序列。这一机制增强了模型的灵活性和表达能力。 3. **硬件感知优化**:Mamba的设计考虑了现代硬件的并行计算能力,通过优化矩阵运算和内存访问模式,实现了高效的训练和推理。 ### Mamba 的应用场景 Mamba模型因其高效的序列建模能力和灵活的设计,在多个领域展现出广泛的应用潜力: 1. **自然语言处理**:Mamba可以用于文本生成、机器翻译和文本摘要等任务。由于其线性复杂度的特点,Mamba在处理长文本时表现出色,尤其是在需要高效处理大规模数据的场景中。 2. **时间序列预测**:Mamba的状态空间模型天然适合时间序列建模,能够捕捉复杂的动态模式。它在金融预测、气象预测和工业监控等领域具有潜在的应用价值。 3. **语音识别与合成**:Mamba的线性复杂度和动态建模能力使其在语音处理任务中表现出色,尤其是在实时语音识别和高质量语音合成方面。 4. **生物信息学**:Mamba可以用于基因序列分析和蛋白质结构预测等任务,帮助研究人员更高效地处理大规模生物数据。 ### 相关论文与资源 1. **经典论文**: - **Mamba: A State Space Approach to Canonical Sequence Modeling**:这是Mamba模型的原始论文,详细介绍了其技术原理和设计思路。 - **Mamba-2: Enhancing the Efficiency and Flexibility of State Space Models**:该论文提出了Mamba-2,进一步优化了Mamba的性能和适用范围。 - **Structured State Space Models for Sequence Modeling**:这篇论文探讨了状态空间模型在序列建模中的应用,并为Mamba的设计提供了理论基础。 2. **开源实现与工具**: - **GitHub 项目**:https://github.com/Event-AHU/Mamba_State_Space_Model_Paper_List 提供了Mamba相关论文的列表以及实现代码。 - **PyTorch 实现**:许多研究者已经发布了基于PyTorch的Mamba实现,方便开发者快速上手和实验。 3. **社区与讨论**: - **学术会议**:Mamba相关的研究成果经常在顶级人工智能会议上展示,如NeurIPS、ICML和ICLR。 - **在线论坛**:Reddit和Stack Overflow等平台上有关于Mamba的讨论,用户可以分享经验或解决问题。 ### 示例代码 以下是一个简单的Mamba模型实现示例,展示了如何构建一个基本的状态空间模型: ```python import torch import torch.nn as nn class MambaModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(MambaModel, self).__init__() self.A = nn.Parameter(torch.randn(hidden_dim, hidden_dim)) self.B = nn.Parameter(torch.randn(hidden_dim, input_dim)) self.C = nn.Parameter(torch.randn(output_dim, hidden_dim)) self.D = nn.Parameter(torch.randn(output_dim, input_dim)) def forward(self, x): batch_size, seq_len, _ = x.shape hidden_state = torch.zeros(batch_size, self.A.shape[0], device=x.device) outputs = [] for t in range(seq_len): u_t = x[:, t, :] hidden_state = torch.matmul(hidden_state, self.A) + torch.matmul(u_t, self.B.t()) y_t = torch.matmul(hidden_state, self.C.t()) + torch.matmul(u_t, self.D.t()) outputs.append(y_t) return torch.stack(outputs, dim=1) # 示例用法 model = MambaModel(input_dim=10, hidden_dim=20, output_dim=5) input_data = torch.randn(32, 50, 10) # Batch size 32, sequence length 50, input dimension 10 output = model(input_data) print(output.shape) # 应该输出 (32, 50, 5) ```
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值