导读: 基于状态空间模型(State Space Model)的Mamba模型最近在深度学习领域有赶超Transformer的势头。其最主要的优势就在于其在长序列任务上的优异性能与较低的计算复杂度。本文就Mamba模型的原理进行解析,分析Mamba模型在结构上与Transformer的不同之处,以及其具有的应用潜力。©️【深蓝AI】
1. 状态空间模型
1.1 状态空间及其离散化
相信SLAM领域的小伙伴们对状态空间模型都并不陌生。在SLAM中,状态空间通常被用来描述所估计的状态参量,其通常具有观测以及预测模型。实际上Mamba中的状态空间模型的灵感恰恰来源于传统的状态空间模型,公式描述如下:
h′(t)=Ah(t)+Bx(t)y(t)=Ch(t)+Dx(t)\begin{aligned} & h^{\prime}(t)=A h(t)+B x(t) \\ & y(t)=C h(t)+D x(t)\end{aligned}h′(t)=Ah(t)+Bx(t)y(t)=Ch(t)+Dx(t)
其中,h(t)h(t)h(t)是当前的状态量,AAA是状态转移矩阵;x(t)x(t)x(t)为输入的控制量,BBB表示控制量对状态量的影响。y(t)y(t)y(t)表示系统的输出,CCC表示当前状态量对输出影响,DDD表示当前控制量对输出影响。
在SLAM领域,我们关心的是怎样使得状态量h(t)h(t)h(t)估计的最优,而在深度学习领域,我们更加关心系统的输出y(t)y(t)y(t)是否准确。通过不断优化AAA,BBB,CCC,DDD使得系统输出的y(t)y(t)y(t)接近理想的真实值。
为了方便运算,下一步需要将上述状态空间模型离散化,离散方法使用的是零阶保持法。实际上就是在状态空间常微分方程的解上做0阶近似,即假设解在采样周期内为常数。
微分方程的解如下:
h(t)=eA(tk+1−tk)h(t)+∫tktk+1eA(t−τ)Bx(τ)dτh(t)=e^{\mathbf{A}\left(t_{k+1}-t_k\right)} h\left(t\right)+\int_{t_k}^{t_{k+1}} e^{\mathbf{A}(t-\tau)} \mathbf{B} x(\tau) d \tauh(t)=eA(tk+1−tk)h(t)+∫tktk+1eA(t−τ)Bx(τ)dτ
零阶保持即假设h(t)h(t)h(t)在采样间隔时间段[tk,tk+1][t_k,t_{k+1}][tk,tk+1]为常数,设Δ=tk+1−tk\Delta=t_{k+1}-t_kΔ=tk+1−tk,得到离散近似的结果为:
hk=Aˉhk−1+Bˉxk,yk=Cˉhk,Aˉ=eΔA,Bˉ=(eΔA−I)A−1B,Cˉ=C\begin{aligned} h_k & =\bar{A} h_{k-1}+\bar{B} x_k, \\ y_k & =\bar{C} h_k, \\ \bar{A} & =e^{\Delta A}, \\ \bar{B} & =\left(e^{\Delta A}-I\right) A^{-1} B, \\ \bar{C} & =C\end{aligned}hkykAˉBˉCˉ=Aˉhk−1+Bˉxk,=Cˉhk,=eΔA,=(eΔA−I)A−1B,=C
注意这里将当前控制量对输出的影响DDD忽略不计。
由此,我们得到了Mamba模型的序列化表示结构。实际上这个结构与RNN类似,RNN的结构如下图1所示。与RNN不同的是,Mamba在计算输出y(t)y(t)y(t)时,直接采用了线性变换,而没有使用激活函数进行非线性化。而这一改变对Mamba模型的计算效率有着很重要的影响。
图1|传统RNN结构©️【深蓝AI】
1.2 卷积形式的状态预测
Mamba中的状态空间模型计算效率最大的提升就在于其序列运算可以卷积化,下面将对状态空间模型的序列卷积化进行详细介绍。从0时刻开始向后推导状态空间模型几个时刻后的输出,我们即可得到如下的形式:
h0=Bˉx0y0=Ch0=CBˉx0h1=Aˉh0+Bˉx1=AˉBˉx0+Bˉx1y1=Ch1=C(AˉBˉx0+Bˉx1)=CAˉBˉx0+CBˉx1h2=Aˉh1+Bˉx2=Aˉ(AˉBˉx0+Bˉx1)+Bˉx2=Aˉ2Bˉx0+AˉBˉx1+Bˉx2y2=Ch2=C(Aˉ2Bˉx0+AˉBˉx1+Bˉx2)=CAˉ2Bˉx0+CAˉBˉx1+CBˉx2yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CAˉBˉxk−1+CBˉxk\begin{aligned} &\begin{aligned} & h_0=\bar{B} x_0 \\ & y_0=C h_0=C \bar{B} x_0 \\ & h_1=\bar{A} h_0+\bar{B} x_1=\bar{A} \bar{B} x_0+\bar{B} x_1 \\ & y_1=C h_1=C\left(\bar{A} \bar{B} x_0+\bar{B} x_1\right)=C \bar{A} \bar{B} x_0+C \bar{B} x_1 \end{aligned}\\ &\begin{aligned} & h_2=\bar{A} h_1+\bar{B} x_2=\bar{A}\left(\bar{A} \bar{B} x_0+\bar{B} x_1\right)+\bar{B} x_2=\bar{A}^2 \bar{B} x_0+\bar{A} \bar{B} x_1+\bar{B} x_2 \\ & y_2=C h_2=C\left(\bar{A}^2 \bar{B} x_0+\bar{A} \bar{B} x_1+\bar{B} x_2\right)=C \bar{A}^2 \bar{B} x_0+C \bar{A} \bar{B} x_1+C \bar{B} x_2 \end{aligned}\\ &y_k=C \bar{A}^k \bar{B} x_0+C \bar{A}^{k-1} \bar{B} x_1+\cdots+C \bar{A} \bar{B} x_{k-1}+C \bar{B} x_k \end{aligned}h0=Bˉx0y0=Ch0=CBˉx0h1=Aˉh0+Bˉx1=AˉBˉx0+Bˉx1y1=Ch1=C(AˉBˉx0+Bˉx1)=CAˉBˉx0+CBˉx1h2=Aˉh1+Bˉx2=Aˉ(AˉBˉx0+Bˉx1)+Bˉx2=Aˉ2Bˉx0+AˉBˉx1+Bˉx2y2=Ch2=C(Aˉ2Bˉx0+AˉBˉx1+Bˉx2)=CAˉ2Bˉx0+CAˉBˉx1+CBˉx2yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CAˉBˉxk−1+CBˉxk左右滑动查看完整公式
此形式可以由卷积运算得到,设计适当的卷积核即可将序列运算转化为卷积运算,形式如下所示:
K‾=(CBˉ,CAB‾,…,CAˉkBˉ,…)y=x∗K‾\begin{aligned} \overline{{\mathbf{K}}} & =\left(C \bar{B}, C \overline{A B}, \ldots, C \bar{A}^k \bar{B}, \ldots\right) \\ y & = x * \overline{\mathbf{K}} \end{aligned}Ky=(CBˉ,CAB,…,CAˉkBˉ,…)=x∗K
如此一来,卷积运算在计算机中便可进行并行计算,这大大加速了状态空间模型的计算速度。而RNN由于在输出时使用了激活函数,因此无法进行卷积化,这也是状态空间模型对比RNN的一大优势所在。
图2|状态空间模型序列运算卷积化©️【深蓝AI】
2. 选择性扫描
传统状态空间模型的时序结构导致了其输出状态完全依赖有序的输入数据。一旦输入数据增减,或者顺序有所变化,那么状态空间模型就无法进行处理,如下图3所示,这种情况传统的状态空间模型完全可以处理,这是由于输出直接复制了输入的Token数据,非常符合状态空间模型的推理。
图3|传统状态空间模型可处理的情况©️【深蓝AI】
然而,如图4所示:对于增减过、顺序打乱过的输入,在一些不相关数据混杂在序列中出现时,状态空间模型就无法对其进行有效处理。
图4|传统状态空间模型难以处理的情况©️【深蓝AI】
Mamba针对这一情况进行了改进,在对BBB,CCC矩阵进行计算时,加入了选择性机制,即在计算是引入一个额外的线性层,对输入的输入的控制量和状态量进行选择,加强模型对不同输入形式的适应能力,算法流程如下图5所示。
图5|选择性机制对状态空间模型的改进©️【深蓝AI】
可以观察到,在引入选择性机制前,状态空间模型被认为是一个时不变系统,即以及不随时间变化;然而,在引入选择性机制后,Aˉ\bar AAˉ以及Bˉ\bar BBˉ随着筛选信号的改变而随时间变化,这是由于筛选机制使得窗口发生了变化,先前所成立的卷积化序列运算似乎又不再成立。为了加快运算,Mamba采用了多线程进行并行计算,对于每个序列利用结合律进行乱序计算,最后通过累加求取结果,流程如下图6所示。
图6|并行累加计算流程©️【深蓝AI】
3. Mamba块结构
与Transformer结构类似,Mamba结构也是由若干Mamba块堆叠而成。一个基本的Mamba块结构如图7所示:Mamba块由H3块以及门控MLP组合而成。H3为Hungry Hungry Hippos,是一种状态空间模型的执行方式。Mamba块简化了H3的结构,并与门控MLP结合,添加了残差项防止梯度消失。
图7|Mamba块结构©️【深蓝AI】
4. 应用潜力
Mamba的主要优势还是其优于Transformer的计算效率。
Mamba的网络结构对于GPU的计算来说十分友好,特别是在数据存取交互上,Mamba结构的数据交互主要集中在GPU何SRAM间,而这部分的数据交互是快速的。计算机内数据交互速度形式如下图8所示。
图8|计算机数据交互结构与速度©️【深蓝AI】
5. 总结
Mamba模型实际可以理解为改进版本的RNN,但是在计算上可以卷积化,进行并行训练,效率高,同时也处理了输入增改、顺序随机及梯度消失等问题。相较于Transformer,Mamba的计算复杂度低,同时也能够很好的保持长序列的关联能力。Mamba模型在CV等领域的扩展应用,在未来可能会井喷式出现。
参考资料:
https://github.com/hkproj/mamba-notes
笔者|Frank
审核|Los
移步公众号【深蓝AI】,第一时间获取自动驾驶、人工智能与机器人行业最新最前沿论文和科技动态。