I have been reading up on Mamba lately. Many blog posts skip over the general solution of the state space representation, so I derive it here.
Glossary
SSM: State Space Model
LTI: Linear Time-Invariant
1. State Space Model
a. SSM equation
The SSM is written as:
$$
\begin{cases}
F'(t) = AF(t) + BG(t), \\
H(t) = CF(t)
\end{cases}
$$
The state space takes time $t$ as its variable, where
$A$: state matrix (known)
$B$: input matrix (known)
$C$: output matrix (known)
$G(t)$: input vector (known)
$F(t)$: state vector (unknown)
$H(t)$: output vector
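To make the symbols concrete, here is a minimal sketch that integrates $F'(t) = AF(t) + BG(t)$ numerically with forward Euler and reads out $H(t) = CF(t)$. The 2-state matrices, the input $G(t) = \sin t$ and the step count are made-up values for illustration, not taken from any paper.

```python
import numpy as np

# Illustrative 2-state, single-input, single-output system (values are made up)
A = np.array([[-1.0, 0.5],
              [0.0, -2.0]])   # state matrix
B = np.array([[1.0],
              [1.0]])         # input matrix
C = np.array([[1.0, -1.0]])   # output matrix

def G(t):
    """Known input signal G(t), returned as a 1-element vector."""
    return np.array([np.sin(t)])

def simulate_euler(F0, t0, t1, n_steps):
    """Integrate F'(t) = A F(t) + B G(t) with forward Euler and record H(t) = C F(t)."""
    dt = (t1 - t0) / n_steps
    F = np.asarray(F0, dtype=float).copy()
    t = t0
    ts, hs = [t], [C @ F]
    for _ in range(n_steps):
        F = F + dt * (A @ F + B @ G(t))   # one Euler step of the state equation
        t += dt
        ts.append(t)
        hs.append(C @ F)
    return np.array(ts), np.array(hs)

ts, hs = simulate_euler([1.0, 0.0], 0.0, 5.0, 5000)
print(ts[-1], hs[-1])
```

Forward Euler only approximates the state trajectory (its error shrinks linearly with the step size); the derivation that follows gives the exact solution.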
We now solve for $F(t)$ and $H(t)$.
First, rearrange the state equation:
$$F'(t) - AF(t) = BG(t)$$
Multiply both sides on the left by $e^{-At}$ (note that $e^{-At}A = Ae^{-At}$, because a matrix commutes with its own exponential):
$$e^{-At}F'(t) - Ae^{-At}F(t) = e^{-At}BG(t)$$
The left-hand side is exactly the derivative of $e^{-At}F(t)$: since $\frac{d}{dt}e^{-At} = -Ae^{-At}$, the product rule gives $\frac{d}{dt}\left[e^{-At}F(t)\right] = e^{-At}F'(t) - Ae^{-At}F(t)$. Hence
$$\frac{d}{dt}\left[e^{-At}F(t)\right] = e^{-At}BG(t)$$
Integrate both sides, renaming the integration variable to $\tau$ and letting it run from $t_0$ to $t$:
$$e^{-At}F(t) - e^{-At_0}F(t_0) = \int_{t_0}^{t} e^{-A\tau}BG(\tau)\,d\tau$$
Moving the $F(t_0)$ term to the right-hand side:
$$e^{-At}F(t) = e^{-At_0}F(t_0) + \int_{t_0}^{t} e^{-A\tau}BG(\tau)\,d\tau$$
Left-multiplying by $e^{At}$ (the inverse of $e^{-At}$) and moving it inside the integral, using $e^{At}e^{-A\tau} = e^{A(t-\tau)}$:
$$
\begin{aligned}
F(t) &= e^{A(t-t_0)}F(t_0) + e^{At}\int_{t_0}^{t} e^{-A\tau}BG(\tau)\,d\tau \\
&= e^{A(t-t_0)}F(t_0) + \int_{t_0}^{t} e^{A(t-\tau)}BG(\tau)\,d\tau
\end{aligned}
$$
Finally:
$$
\begin{cases}
F(t) = e^{A(t-t_0)}F(t_0) + \int_{t_0}^{t} e^{A(t-\tau)}BG(\tau)\,d\tau \\
H(t) = CF(t)
\end{cases}
\tag{1}
$$
This completes the derivation.
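As a numerical sanity check of equation (1), the sketch below evaluates the closed form (matrix exponential plus a trapezoid-rule integral) and compares it with a classic RK4 integration of $F'(t) = AF(t) + BG(t)$. The matrices, input $G(t) = \sin t$ and grid sizes are illustrative assumptions, and $e^{As}$ is computed by eigendecomposition, which assumes $A$ is diagonalizable (true for this $A$, whose eigenvalues are $-1$ and $-2$).

```python
import numpy as np

A = np.array([[-1.0, 0.5],
              [0.0, -2.0]])
B = np.array([1.0, 1.0])      # single input, so B is stored as a vector
C = np.array([1.0, -1.0])

def G(t):
    return np.sin(t)          # scalar input signal (illustrative)

# e^{As} via A = V diag(lam) V^{-1}; assumes A is diagonalizable
lam, V = np.linalg.eig(A)
V_inv = np.linalg.inv(V)

def expm_A(s):
    return (V @ np.diag(np.exp(lam * s)) @ V_inv).real

def closed_form(F0, t0, t, n_quad=2001):
    """Equation (1): e^{A(t-t0)} F(t0) + int_{t0}^{t} e^{A(t-tau)} B G(tau) dtau."""
    taus = np.linspace(t0, t, n_quad)
    integrand = np.array([expm_A(t - tau) @ B * G(tau) for tau in taus])  # shape (n_quad, 2)
    dt = taus[1] - taus[0]
    integral = dt * (integrand.sum(axis=0) - 0.5 * (integrand[0] + integrand[-1]))  # trapezoid rule
    return expm_A(t - t0) @ F0 + integral

def rk4(F0, t0, t, n_steps=2000):
    """Reference solution: integrate F' = A F + B G(t) with classic RK4."""
    f = lambda s, F: A @ F + B * G(s)
    h = (t - t0) / n_steps
    F, s = F0.astype(float), t0
    for _ in range(n_steps):
        k1 = f(s, F)
        k2 = f(s + h / 2, F + h / 2 * k1)
        k3 = f(s + h / 2, F + h / 2 * k2)
        k4 = f(s + h, F + h * k3)
        F = F + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        s += h
    return F

F0 = np.array([1.0, 0.0])
F_exact = closed_form(F0, 0.0, 3.0)
F_ode = rk4(F0, 0.0, 3.0)
print(F_exact, F_ode, C @ F_exact)  # the last value is H(3) = C F(3)
```

The two results should agree up to the quadrature error of the trapezoid rule.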
Relation to convolution
Recall the form of a convolution. Suppose a kernel $k(s)$ is convolved with a function $I(t)$; the (causal) result is
$$\mathrm{Conv}(t) = (k \ast I)(t) = \int_{t_0}^{t} k(t-s)\,I(s)\,ds$$
Its discrete form is
$$\mathrm{Conv}(t) = \sum_{s=t_0}^{t} k(t-s)\,I(s)$$
For example, with $t_0 = 0$ and $t = 0, 1, \dots, n$, the discrete convolution unrolls as
$$
\begin{aligned}
\mathrm{Conv}(0) &= k(0)I(0) \\
\mathrm{Conv}(1) &= k(1)I(0) + k(0)I(1) \\
\mathrm{Conv}(2) &= k(2)I(0) + k(1)I(1) + k(0)I(2) \\
\mathrm{Conv}(3) &= k(3)I(0) + k(2)I(1) + k(1)I(2) + k(0)I(3) \\
&\;\;\vdots \\
\mathrm{Conv}(n) &= k(n)I(0) + k(n-1)I(1) + \dots + k(0)I(n)
\end{aligned}
$$
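The discrete sum above can be written directly as a double loop. This is a minimal sketch (the kernel and input values are arbitrary); kernel taps beyond the kernel's length are treated as zero, and the result matches the first `len(I)` entries of NumPy's full convolution.

```python
import numpy as np

def causal_conv(k, I):
    """Conv(t) = sum_{s=0}^{t} k(t-s) I(s) for t = 0..len(I)-1 (t_0 = 0)."""
    n = len(I)
    out = np.zeros(n)
    for t in range(n):
        for s in range(t + 1):
            if t - s < len(k):          # taps beyond the kernel length count as 0
                out[t] += k[t - s] * I[s]
    return out

k = np.array([1.0, 0.5, 0.25])
I = np.array([1.0, 2.0, 3.0, 4.0])
print(causal_conv(k, I).tolist())  # [1.0, 2.5, 4.25, 6.0]
```

For instance, $\mathrm{Conv}(1) = k(1)I(0) + k(0)I(1) = 0.5 + 2 = 2.5$, exactly as in the unrolled table.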
Now look at the second term of state space equation $(1)$:
$$\int_{t_0}^{t} e^{A(t-\tau)}BG(\tau)\,d\tau$$
It is also a convolution: the kernel is $k(s) = e^{As}B$, evaluated at $s = t - \tau$, and the input function is $G(\tau)$.
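To see the convolution view numerically, here is a sketch in the scalar case (one state, so $e^{As}B = b\,e^{as}$). The values $a = -0.8$, $b = 2$ and the constant input $G(\tau) = 1$ are assumptions chosen so the integral has the closed form $b\,(e^{at} - 1)/a$. Sampling on a grid of step `dt` turns the integral into a Riemann sum, i.e. a discrete causal convolution of the sampled kernel with the sampled input.

```python
import numpy as np

a, b = -0.8, 2.0          # hypothetical scalar A and B
dt, n = 1e-3, 4000        # grid: tau_j = j * dt, j = 0..n-1
s = np.arange(n) * dt

kernel = b * np.exp(a * s)   # k(s) = e^{A s} B
G = np.ones(n)               # constant input G(tau) = 1

# Riemann-sum approximation of int_0^t k(t - tau) G(tau) dtau for every grid time t (error is O(dt))
conv_term = dt * np.convolve(kernel, G)[:n]

exact = b * (np.exp(a * s) - 1.0) / a   # closed form of the same integral when G = 1
print(np.max(np.abs(conv_term - exact)))
```

The maximum deviation stays on the order of `dt`, which is the expected accuracy of a simple Riemann sum.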
b. Discretization
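The Mamba paper notes that to apply the SSM equation above in practice, it must be discretized: the continuous time axis is split into $K$ discrete intervals. This is done with the Zero-Order Hold (ZOH) method, which assumes that the input $G(t)$ stays constant within each interval $[t_{k-1}, t_k]$.
For the detailed derivation, see the original ZOH paper; it is not repeated here.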
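The key step is short enough to sketch. Assume $A$ is invertible, let $\Delta = t_k - t_{k-1}$, and assume the held input value on $[t_{k-1}, t_k]$ is $G_k$ (the sample indexed by the current step). Applying equation (1) over one interval with $t_0 = t_{k-1}$ and $t = t_k$:

$$
\begin{aligned}
F(t_k) &= e^{A\Delta}F(t_{k-1}) + \left(\int_{t_{k-1}}^{t_k} e^{A(t_k-\tau)}\,d\tau\right) B\,G_k \\
&= e^{A\Delta}F(t_{k-1}) + \left(\int_{0}^{\Delta} e^{As}\,ds\right) B\,G_k \qquad (s = t_k - \tau) \\
&= e^{A\Delta}F(t_{k-1}) + A^{-1}\left(e^{A\Delta} - I\right) B\,G_k
\end{aligned}
$$

Since $A^{-1}(e^{A\Delta} - I)B = (A\Delta)^{-1}(e^{A\Delta} - I)\cdot B\Delta$, the two coefficients are exactly the $\bar{A}$ and $\bar{B}$ given below.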
After ZOH discretization, the SSM becomes (with $t$ now indexing discrete steps):
$$
\begin{cases}
F(t) = \bar{A}F(t-1) + \bar{B}G(t), \\
H(t) = CF(t)
\end{cases}
$$
where $\Delta(t)$ is the step size (the length of the interval $[t_{k-1}, t_k]$) and
$$\bar{A} = e^{A\Delta(t)}, \qquad \bar{B} = (A\Delta(t))^{-1}(\bar{A} - I)\cdot B\Delta(t)$$
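A minimal NumPy sketch of the ZOH formulas above, assuming $A$ is invertible. The helper `expm_taylor` is a hypothetical stand-in for a library matrix exponential (such as `scipy.linalg.expm`), implemented with scaling and squaring plus a truncated Taylor series; $(A\Delta)^{-1}$ is applied with `np.linalg.solve` instead of forming the inverse explicitly.

```python
import numpy as np

def expm_taylor(M, n_terms=30):
    """Matrix exponential via scaling-and-squaring plus a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / (2.0 ** s)            # scale down so the series converges quickly
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, n_terms):
        term = term @ X / k       # accumulates X^k / k!
        E = E + term
    for _ in range(s):            # undo the scaling: e^M = (e^{M / 2^s})^(2^s)
        E = E @ E
    return E

def discretize_zoh(A, B, delta):
    """A_bar = e^{A delta}, B_bar = (A delta)^{-1} (A_bar - I) B delta; assumes A is invertible."""
    dA = delta * A
    A_bar = expm_taylor(dA)
    I = np.eye(A.shape[0])
    B_bar = np.linalg.solve(dA, (A_bar - I) @ (delta * B))
    return A_bar, B_bar

A = np.array([[-1.0, 0.5],
              [0.0, -2.0]])     # illustrative values
B = np.array([[1.0],
              [1.0]])
A_bar, B_bar = discretize_zoh(A, B, delta=0.1)
print(A_bar, B_bar, sep="\n")
```

For a diagonal $A = \mathrm{diag}(a_i)$ the formulas reduce to $\bar{A}_{ii} = e^{a_i\Delta}$ and $\bar{B}_i = (e^{a_i\Delta} - 1)\,b_i / a_i$, which gives an easy check.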
Relation to RNNs
Recall the form of an RNN:
$$
\begin{cases}
h_t = \sigma(Uh_{t-1} + Wx_t + b_h) \\
y_t = \mathrm{softmax}(Vh_t + b_y)
\end{cases}
$$
$\sigma$: activation function
$x_t$: input vector
$h_t$: hidden state
$W$: input weights
$U$: hidden-state weights
$V$: output weights
$b_h$ and $b_y$: biases
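So the SSM state $F(t)$ can be viewed as a special kind of hidden state with no activation function: taking $\sigma$ as the identity, $U = \bar{A}$, $W = \bar{B}$ and $b_h = 0$ gives exactly the discrete update $F(t) = \bar{A}F(t-1) + \bar{B}G(t)$, and the readout $H(t) = CF(t)$ plays the role of $Vh_t$ without the softmax.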
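A small sketch of that correspondence: a vanilla RNN cell with identity activation and zero bias, fed $U = \bar{A}$ and $W = \bar{B}$, produces the same states as the discrete SSM update. The matrices are hypothetical values (a diagonal $\bar{A}$ only keeps the example small), and the inputs are random.

```python
import numpy as np

def rnn_step(h_prev, x_t, U, W, b_h, sigma):
    """Vanilla RNN hidden-state update: h_t = sigma(U h_{t-1} + W x_t + b_h)."""
    return sigma(U @ h_prev + W @ x_t + b_h)

def ssm_step(F_prev, G_t, A_bar, B_bar):
    """Discretized SSM state update: F(t) = A_bar F(t-1) + B_bar G(t)."""
    return A_bar @ F_prev + B_bar @ G_t

A_bar = np.diag([0.9, 0.5])        # hypothetical discretized state matrix
B_bar = np.array([[0.1], [0.4]])   # hypothetical discretized input matrix
identity = lambda z: z             # "no activation function"

rng = np.random.default_rng(0)
h = np.zeros(2)
F = np.zeros(2)
for x in rng.normal(size=(5, 1)):  # five 1-dimensional inputs
    h = rnn_step(h, x, U=A_bar, W=B_bar, b_h=np.zeros(2), sigma=identity)
    F = ssm_step(F, x, A_bar, B_bar)
print(np.allclose(h, F))
```

Swapping `identity` for a nonlinearity such as `np.tanh` breaks the equivalence, which is precisely what makes the SSM a linear recurrence.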
Next, we unroll the discretized SSM step by step, writing $x_k$ for the input $G$ and $y_k$ for the output $H$ at step $k$, and starting from a zero state:
$$
\begin{aligned}
y_0 &= C\bar{B}x_0 \\
y_1 &= C\bar{A}\bar{B}x_0 + C\bar{B}x_1 \\
y_2 &= C\bar{A}^2\bar{B}x_0 + C\bar{A}\bar{B}x_1 + C\bar{B}x_2 \\
&\;\;\vdots \\
y_k &= C\bar{A}^k\bar{B}x_0 + C\bar{A}^{k-1}\bar{B}x_1 + \dots + C\bar{A}\bar{B}x_{k-1} + C\bar{B}x_k
\end{aligned}
$$
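To check the unrolled formula, this sketch runs the recurrence and compares it term by term with $y_k = \sum_{j=0}^{k} C\bar{A}^{k-j}\bar{B}x_j$. The values of $\bar{A}$, $\bar{B}$, $C$ and the inputs are hypothetical, and a single input channel is assumed so that $\bar{B}$ is a vector and each $x_k$ is a scalar.

```python
import numpy as np

def scan(A_bar, B_bar, C, xs):
    """Recurrent form: F_k = A_bar F_{k-1} + B_bar x_k, y_k = C F_k, starting from a zero state."""
    F = np.zeros(A_bar.shape[0])
    ys = []
    for x in xs:
        F = A_bar @ F + B_bar * x
        ys.append(C @ F)
    return np.array(ys)

def unrolled(A_bar, B_bar, C, xs, k):
    """Unrolled form: y_k = sum_{j=0}^{k} C A_bar^{k-j} B_bar x_j."""
    return sum(C @ np.linalg.matrix_power(A_bar, k - j) @ B_bar * xs[j] for j in range(k + 1))

A_bar = np.array([[0.9, 0.1],
                  [0.0, 0.7]])      # hypothetical values
B_bar = np.array([0.5, 1.0])        # single input channel, stored as a vector
C = np.array([1.0, -2.0])
xs = np.array([1.0, 0.0, 2.0, -1.0, 0.5])

ys = scan(A_bar, B_bar, C, xs)
print(ys)
print([unrolled(A_bar, B_bar, C, xs, k) for k in range(len(xs))])
```

The first output is $y_0 = C\bar{B}x_0 = (0.5 - 2) \cdot 1 = -1.5$, matching the first row of the unrolled table.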
The coefficients above can be collected into a kernel
$$\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, \dots, C\bar{A}^k\bar{B}, \dots)$$
so the recurrent computation can be rewritten as a convolution:
$$y = x \ast \bar{K}$$
where $x = [x_0, x_1, \dots]$ and $y = [y_0, y_1, \dots]$.
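A sketch of the two equivalent computations: build $\bar{K}$ once, convolve it with $x$, and compare with the step-by-step recurrence. The diagonal $\bar{A}$ and the random $\bar{B}$, $C$, $x$ are illustrative assumptions for a single input/output channel.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """K_bar = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar) for one input/output channel."""
    K = np.empty(L)
    v = B_bar.copy()
    for i in range(L):
        K[i] = C @ v        # C A_bar^i B_bar
        v = A_bar @ v       # advance to A_bar^{i+1} B_bar
    return K

def ssm_conv(A_bar, B_bar, C, x):
    """Convolutional form y = x * K_bar, truncated to the input length (causal)."""
    K = ssm_kernel(A_bar, B_bar, C, len(x))
    return np.convolve(x, K)[: len(x)]

def ssm_scan(A_bar, B_bar, C, x):
    """Recurrent form, for comparison."""
    F = np.zeros(A_bar.shape[0])
    y = np.empty(len(x))
    for k, x_k in enumerate(x):
        F = A_bar @ F + B_bar * x_k
        y[k] = C @ F
    return y

rng = np.random.default_rng(1)
A_bar = np.diag([0.95, 0.6, -0.3])   # hypothetical stable discretized state matrix
B_bar = rng.normal(size=3)
C = rng.normal(size=3)
x = rng.normal(size=16)

print(np.allclose(ssm_conv(A_bar, B_bar, C, x), ssm_scan(A_bar, B_bar, C, x)))
```

Because $\bar{A}$, $\bar{B}$ and $C$ are fixed, $\bar{K}$ only has to be computed once and the whole output sequence can be produced in parallel (for long sequences, typically with an FFT-based convolution), while the recurrent form processes one step at a time.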
c. SSM properties
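A traditional SSM is time-invariant: $\bar{A}$, $\bar{B}$, $C$ and $\Delta$ do not depend on the model input $x$. This limits the model's ability to take context into account, which is why traditional SSMs perform poorly on the Selective Copying task.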
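Time invariance can be checked directly. In the sketch below (hypothetical parameter values, zero initial state), delaying the input by two steps simply delays the output by two steps: the same fixed weights $C\bar{A}^i\bar{B}$ are applied no matter what the input values are, so the model cannot decide based on content which inputs to keep.

```python
import numpy as np

def ssm_scan(A_bar, B_bar, C, x):
    """Discrete LTI SSM with fixed parameters, starting from a zero state."""
    F = np.zeros(A_bar.shape[0])
    y = np.empty(len(x))
    for k, x_k in enumerate(x):
        F = A_bar @ F + B_bar * x_k
        y[k] = C @ F
    return y

A_bar = np.diag([0.9, 0.5])   # fixed, input-independent parameters (hypothetical values)
B_bar = np.array([1.0, 0.5])
C = np.array([0.3, -1.0])

x = np.array([0.0, 1.0, -2.0, 0.5, 0.0, 0.0, 0.0])
shift = 2
x_shifted = np.concatenate([np.zeros(shift), x[:-shift]])   # delay the input by 2 steps

y = ssm_scan(A_bar, B_bar, C, x)
y_shifted = ssm_scan(A_bar, B_bar, C, x_shifted)
print(np.allclose(y_shifted[shift:], y[:-shift]))
```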
2. Mamba
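SSMs perform poorly on text and other information-dense data. To give SSMs modeling power comparable to Transformers, Albert Gu and Tri Dao proposed three new techniques:
a) Structured State Space (S4): e.g., memory initialization based on HiPPO (High-order Polynomial Projection Operator)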
b) Selective Mechanism
c) Hardware-aware Computation
To be continued…
References
- https://orb.binghamton.edu/cgi/viewcontent.cgi?filename=7&article=1002&context=electrical_fac&type=additional
- https://zh.wikipedia.org/wiki/%E7%8A%B6%E6%80%81%E7%A9%BA%E9%97%B4