目录
Mamba原理:参考链接1,参考链接2,参考链接3
Mamba代码:参考链接,mamba-mini实现参考链接官方mamba库链接: https://github.com/state-spaces/mamba
(注:以下以1.1.4版本mamba-ssm库的代码为例)🥳🥳🥳欢迎各位大佬一起来交流讨论🥳🥳🥳
1 代码原理简述
熟悉原理的话可以直接看第二章代码
原理介绍摘录参考了链接,(这位大佬的博客应该是全网最详细的了,想摸清原理一定要读一下)
Mamba算法的本质是借鉴了现代控制理论中的状态空间模型SSM。
1.1 原始结构——SSM
状态空间模型SSM的公式如下:
通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态。(A、B、C、D这4个矩阵是可学习的参数,是可以学习到的,学习好之后,在SSM中,矩阵A、B、C、D便固定不变了,但到后续mamba中则这4个矩阵可以随着输入不同而可变)
这一过程可以由如下图来表示:
1.2 结构改进——S4(Structured State Space for Sequences)
S4模型对SSM的改进有以下三点:
- 采用零阶保持,来进行连续化:由于SSM模型是针对连续函数的,但是在文本、图像等领域,数据都是离散的,所以我们需要将离散的点连续化,才能输入进SSM模型,最后再从连续的输出中采样离散的点来得到真正的输出
- 使用循环+卷积结构表示,从而能够并行训练,加快训练速度
- 使用HIPPO矩阵,以处理远程依赖
1.2.1 离散化
由于以上公式是针对连续变量的,单实际计算机相关的应用中往往离散的,所以经过采样和零阶保持的方法,便有了以下离散化公式(这个公式很重要,在代码中会体现):
然后公式变成了这样:
在每个时间步,都会涉及到隐藏状态的更新(比如
h
k
h_{k}
hk取决于
B
‾
x
k
\overline{\mathrm{B}}\mathrm{x}_{k}
Bxk和
A
‾
h
k
−
1
\overline{\mathrm{A}}\mathrm{h}_{k-1}
Ahk−1的共同作用结果,然后通过
C
h
k
Ch_k
Chk预测输出
y
k
y_k
yk),这样反复套娃,就得到了类似RNN的循环结构:
换个图表示更明显:
用公式来表示就是这样的:
y
2
=
C
h
2
=
C
(
A
ˉ
h
1
+
B
ˉ
x
2
)
=
C
(
A
ˉ
(
A
ˉ
h
0
+
B
ˉ
x
1
)
+
B
ˉ
x
2
)
=
C
(
A
ˉ
(
A
ˉ
⋅
B
ˉ
x
0
+
B
ˉ
x
1
)
+
B
ˉ
x
2
)
=
C
(
A
ˉ
⋅
A
ˉ
⋅
B
ˉ
x
0
+
A
ˉ
⋅
B
ˉ
x
1
+
B
ˉ
x
2
)
=
C
⋅
A
ˉ
2
⋅
B
ˉ
x
0
+
C
⋅
A
ˉ
⋅
B
ˉ
⋅
x
1
+
C
⋅
B
ˉ
x
2
\begin{aligned} y_2& =Ch_{2} \\ &=C\left(\bar{A}h_{1}+\bar{B}x_{2}\right) \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}h_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}\cdot\bar{B}x_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\cdot\bar{A}\cdot\bar{B}x_{0}+\bar{A}\cdot\bar{B}x_{1}+\bar{B}x_{2}\end{pmatrix} \\ &=C\cdot\bar{A}^{2}\cdot\bar{B}x_{0}+C\cdot\bar{A}\cdot\bar{B}\cdot x_{1}+C\cdot\bar{B}x_{2} \end{aligned}
y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(Aˉ⋅Bˉx0+Bˉx1)+Bˉx2)=C(Aˉ⋅Aˉ⋅Bˉx0+Aˉ⋅Bˉx1+Bˉx2)=C⋅Aˉ2⋅Bˉx0+C⋅Aˉ⋅Bˉ⋅x1+C⋅Bˉx2
由此类推:
y
3
=
C
A
A
A
B
‾
x
0
+
C
A
A
B
‾
x
1
+
C
A
B
‾
x
2
+
C
B
‾
x
3
y_{3}=\mathbf{C\overline{AAAB}}x_{0}+\mathbf{C\overline{AAB}}x_{1}+\mathbf{C\overline{AB}}x_{2}+\mathbf{C\overline{B}}x_{3}
y3=CAAABx0+CAABx1+CABx2+CBx3
y
k
=
C
A
ˉ
k
B
ˉ
x
0
+
C
A
ˉ
k
−
1
B
ˉ
x
1
+
⋯
+
C
A
ˉ
B
ˉ
x
k
−
1
+
C
B
ˉ
x
k
y_{k}=C\bar{A}^{k}\bar{B}x_{0}+C\bar{A}^{k-1}\bar{B}x_{1}+\cdots+C\bar{A}\bar{B}x_{k-1}+C\bar{B}x_{k}
yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CAˉBˉxk−1+CBˉxk
这样看来循环结构就更明显了。
1.2.2HiPPO
HiPPO是一种解决长距离依赖问题的方法,用于构建并调整状态矩阵A使其逼近最优解,这种方法比把A初始化为随机矩阵要好得多
1.3 最终版本——Mamba(又称S6或selective SSMs)
Mamba模型对于S4模型的改进有以下三点:
- 参数化矩阵:对输入信息进行有选择性的处理,从而得到类似Attention的效果,即不同的输入拥有不同的状态,token信息
- 硬件感知算法,并行化–选择扫描算法,加快训练推理速度
- 更简化的SSM模型架构
在最终的Mamaba中,作者让B矩阵、C矩阵、
Δ
\Delta
Δ 成为输入的函数,让模型能够根据输入内容自适应地调整其行为
最终的整体流程:
结构细节如下:
其中的“选择性SSM(即Selective SSM)”具有以下属性:
- Recurrent SSM通过离散化创建循环SSM
- HiPPO对矩阵A进行初始化A以捕获长程依赖性
- 选择性扫描算法(Selective scan algorithm)选择性压缩信息
- 硬件感知算法(Hardware-aware algorithm)加速计算
Mamba和Trasformer,RNN相对的优势:
2 代码库目录结构
mamba
├── benchmarks
│ └── benchmark_generation_mamba_simple.py // 示例模型的推理脚本
├── csrc
│ └── selective_scan // 选择性扫描的c++实现
├── evals
│ └── lm_harness_eval.py
├── mamba_ssm
│ ├── models
│ │ ├── config_mamba.py
│ │ └── mixer_seq_simple.py // 使用mamba构建的一个完整的语言模型示例
│ ├── modules
│ │ └── mamba_simple.py // mamba block的实现
│ ├── ops
│ │ ├── triton
│ │ │ ├── layernorm.py
│ │ │ ├── selective_state_update.py
│ │ └── selective_scan_interface.py // 选择性SSM层的实现
│ ├── utils
│ │ ├── generation.py
│ │ └── hf.py
└── test
└── ops
├── triton
│ ├── test_selective_state_update.py
└──test_selective_scan.py
2.1 mamba_simple.py主体结构
2.1.1 Mamba类
class Mamba(nn.Module):
def __init__(
...
def forward(self, hidden_states, inference_params=None):
...
def step(self, hidden_states, conv_state, ssm_state):
...
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
...
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
...
(1)__init__函数
Mamba类/__init__函数负责各个模块的初始化
- 输入投影初始化:输入特征通过线性投影被扩展为两倍的内部特征维度,准备进行卷积和状态空间模型的计算。
- 卷积层:使用 1D 卷积操作,捕捉局部的时间特征,使用
d_conv
作为卷积核大小。 - 状态空间模型:通过
A_log
矩阵初始化状态矩阵,并通过跳跃连接D
来捕捉长期依赖。 - 时间常数初始化:时间常数
dt
通过dt_proj
层生成,并在初始化时通过 softplus 函数的逆函数来确保其在合理范围内。 - 输出层:通过线性投影,将扩展后的特征维度恢复到输入特征维度,完成一步处理后的输出。
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model # 输入特征维度
self.d_state = d_state # 状态空间模型的维度
self.d_conv = d_conv # 卷积核大小
self.expand = expand # 特征扩展倍数
self.d_inner = int(self.expand * self.d_model) # 扩展后的内部维度
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank # 自动确定时间常数的秩
self.use_fast_path = use_fast_path # 是否使用优化的快速路径
self.layer_idx = layer_idx # 当前层的索引
# 线性投影层,将输入特征扩展为内部维度的两倍
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
# 1D 卷积操作,卷积核大小为 `d_conv`,输入和输出通道数均为 `d_inner`
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
# 激活函数 SiLU (Sigmoid Linear Unit),一种常用的非线性激活
self.activation = "silu"
self.act = nn.SiLU()
# 再次线性投影,用于生成时间常数、B 和 C 矩阵
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
# 时间常数的线性投影层
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
# 初始化时间常数的投影层参数,确保初始化时方差保持稳定
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# 初始化 dt 投影的偏置,使得F.softplus(dt_bias)后的值在 `dt_min` 和 `dt_max` 之间
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# 使用 softplus 的逆函数来计算偏置
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# 我们的初始化会将所有的Linear.bias变成0,所以这里标记Linear.bias不需要重新初始化
self.dt_proj.bias._no_reinit = True
# S4D 初始化过程,状态矩阵 A 的对数
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # 对 A 取对数,保持精度fp32
self.A_log_temp = GwParameters(A_log)
self.A_log = self.A_log_temp.param
self.A_log._no_weight_decay = True # 防止权重衰减
# 跳跃连接参数 D
self.D_temp = GwParameters(torch.ones(self.d_inner, device=device))
self.D = self.D_temp.param
self.D._no_weight_decay = True
# 输出层的线性投影,将内部维度投影回原始特征维度
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
(2)forward函数
Mamba类/ forward 函数负责在前向传播过程中对输入的序列数据进行处理。它主要包括以下几步:
- 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
- 状态空间模型 (SSM):通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
- 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
- 时间常数的处理:通过
dt_proj
线性投影将特定的时间常数应用到特征上,适应不同时序的动态。 - 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
[def forward(self, hidden_states, inference_params=None):
"""
主要思路:
1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 D
Returns: 返回与输入相同形状的输出
"""
batch, seqlen, dim = hidden_states.shape
# 初始化卷积和状态空间模型的状态为 None
conv_state, ssm_state = None, None
if inference_params is not None:# 如果推理参数不为空
# 从缓存中获取卷积和状态空间模型的状态
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# 调用 step 函数逐步更新状态,并返回输出
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBL
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None: # 如果存在偏置
# 将偏置加到投影结果上
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
A = -torch.exp(self.A_log.float()) # 计算 A 矩阵 (d_inner, d_state)
# 如果使用快速路径 (不使用推理状态且启用了快速内核)
# 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作
if self.use_fast_path and inference_params is None: # 不支持输出状态
# 使用快速内核进行计算
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else: # 如果不使用快速路径 (常规路径)
# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))
x, z = xz.chunk(2, dim=1) # x 和 z 分别为输入投影的两部分
# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)
if conv_state is not None:
# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # 更新卷积状态 (B D W)
# 执行一维卷积并进行激活操作 (如果存在加速的卷积函数,则使用加速路径)
if causal_conv1d_fn is None: # 如果没有使用加速卷积函数
x = self.act(self.conv1d(x)[..., :seqlen]) # 常规卷积操作并激活
else: # 使用加速卷积函数
x = causal_conv1d_fn(
x=x, # 输入
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), # 调整卷积权重
bias=self.conv1d.bias, # 卷积偏置
activation=self.activation, # 激活函数 (silu 或 swish)
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
# 对卷积结果进行投影以生成 dt, B 和 C
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # 线性投影 (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C
dt = self.dt_proj.weight @ dt.t() # 对时间常数 dt 进行线性变换
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) # 重塑为 (b, d, l)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 B
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 C
assert self.activation in ["silu", "swish"]
# 使用 selective_scan_fn 进行状态空间模型计算
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
# 如果存在状态空间模型的状态缓存
if ssm_state is not None:
y, last_state = y # 更新最后的状态
ssm_state.copy_(last_state) # 将更新后的状态缓存起来
# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)
y = rearrange(y, "b d l -> b l d")
# 通过输出层的线性投影恢复原始特征维度
out = self.out_proj(y)
return out](<def forward(self, hidden_states, inference_params=None):
"""
主要思路:
1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 D
Returns: 返回与输入相同形状的输出
"""
batch, seqlen, dim = hidden_states.shape
# 初始化卷积和状态空间模型的状态为 None
conv_state, ssm_state = None, None
# TODO: inference_params这个参数有什么用?什么时候会被调整?
# ============================================== inference_params is True ==============================================
if inference_params is not None:# 如果推理参数不为空
# 从缓存中获取卷积和状态空间模型的状态
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset %3E 0:
# 调用 step 函数逐步更新状态,并返回输出
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# ======================================================================================================================
# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBL
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None: # 如果存在偏置
# 将偏置加到投影结果上
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") # 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作
A = -torch.exp(self.A_log.float()) # 计算 A 矩阵 (d_inner, d_state)
# ============================================== inference_params is None ==============================================
# 如果使用快速路径,整体的思路和非快速一致,可以看else中的解释
if self.use_fast_path and inference_params is None: # (不使用推理状态且启用了快速内核)不支持输出状态
# 使用快速内核进行计算
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else: # 如果不使用快速路径 (常规路径)
# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))
x, z = xz.chunk(2, dim=1) # x 和 z 分别为输入投影的两部分
# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)
if conv_state is not None:
# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # 更新卷积状态 (B D W)
# ===================== Conv部分+激活操作 (如果存在加速的卷积函数,则使用加速路径) ↓=======================
if causal_conv1d_fn is None: # 如果没有使用加速卷积函数
x = self.act(self.conv1d(x)[..., :seqlen]) # 常规卷积操作并激活
else: # 使用加速卷积函数
x = causal_conv1d_fn(
x=x, # 输入
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), # 调整卷积权重
bias=self.conv1d.bias, # 卷积偏置
activation=self.activation, # 激活函数 (silu 或 swish)
)
# ====================================== Conv部分+激活操作 ↑========================================
# ======================================== SSM 部分 ↓ ========================================
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
# ==对卷积结果进行投影以生成 dt, B 和 C==
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # 线性投影 (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C
dt = self.dt_proj.weight @ dt.t() # 对时间常数 dt 进行线性变换
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) # 重塑为 (b, d, l)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 B
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 C
#====================================
assert self.activation in ["silu", "swish"]
# 使用 selective_scan_fn 进行状态空间模型计算
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
# 如果存在状态空间模型的状态缓存
if ssm_state is not None:
y, last_state = y # 更新最后的状态
ssm_state.copy_(last_state) # 将更新后的状态缓存起来
# ==================================== SSM 部分 ↑ =============================================
# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)
y = rearrange(y, "b d l -> b l d")
# 通过输出层的线性投影恢复原始特征维度
out = self.out_proj(y)
return out>)
(3)step函数
Mamba类/ step 函数, 是一次完整的mamba块处理过程(inference_params不为空时执行)。负责在解码过程中处理单个时间步长的数据,并更新卷积状态和状态空间模型(SSM)的状态。主要逻辑分为两部分:
- 卷积步骤 (Conv Step):对输入进行卷积操作,用来捕捉局部的时序特征。
- 状态空间模型步骤 (SSM Step):利用状态空间模型进行时间序列的动态建模,捕捉序列中的长期依赖关系。
[def step(self, hidden_states, conv_state, ssm_state):
"""
执行过程中单步处理
hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度
conv_state: 卷积的状态缓存 (B, D, W)
ssm_state: 状态空间模型的状态缓存 (B, D, S)
Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state
"""
# 确保输入的第二维是1,即仅处理一个 token
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -> (B, 2D)
xz = self.in_proj(hidden_states.squeeze(1)) # (B, 2D)
# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)
x, z = xz.chunk(2, dim=-1) # (B, D)
# 卷积操作
if causal_conv1d_update is None: # 如果没有使用加速卷积更新函数
# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = x # 将当前输入 x 放到状态缓存的最后一个位置
# 执行一维卷积,卷积结果为 (B, D)
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
if self.conv1d.bias is not None: # 如果存在卷积偏置
x = x + self.conv1d.bias # 加上偏置
# 激活函数处理并将数据类型转回为原始输入的 dtype
x = self.act(x).to(dtype=dtype)
else: # 如果使用了加速的卷积更新函数
x = causal_conv1d_update(
x, # 输入
conv_state, # 卷积状态
rearrange(self.conv1d.weight, "d 1 w -> d w"), # 卷积权重
self.conv1d.bias, # 卷积偏置
self.activation, # 激活函数 (如 silu 或 swish)
)
# 对卷积后的 x 进行线性投影,得到 dt, B 和 C
x_db = self.x_proj(x) # (B, dt_rank + 2 * d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C
# 时间常数 dt 通过线性投影更新,不加偏置
dt = F.linear(dt, self.dt_proj.weight) # (B, d_inner)
# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# 状态空间模型步骤 (SSM Step)
if selective_state_update is None: # 如果没有使用选择性状态更新函数
# 离散化 A 和 B 矩阵
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) # 对 dt 进行 softplus 激活并加上偏置
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) # 计算离散化后的 A 矩阵
dB = torch.einsum("bd,bn->bdn", dt, B) # 计算离散化后的 B 矩阵
# 更新状态空间模型的状态
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
# 计算输出 y
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) # 状态乘以 C 得到输出
y = y + self.D.to(dtype) * x # 加上 D * x 的跳跃连接
y = y * self.act(z) # 与 z 经过激活函数的结果相乘 (B, D)
else: # 如果使用了选择性状态更新函数
y = selective_state_update( # 调用加速的选择性状态更新函数
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
# 将输出通过线性投影层
out = self.out_proj(y) # (B, D)
return out.unsqueeze(1), conv_state, ssm_state # 返回输出 (B, 1, D) 以及更新后的状态](<def step(self, hidden_states, conv_state, ssm_state):
"""
执行过程中单步处理
hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度
conv_state: 卷积的状态缓存 (B, D, W)
ssm_state: 状态空间模型的状态缓存 (B, D, S)
Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state
"""
# ===================== 输入预处理部分 ↓ =====================
# 确保输入的第二维是1,即仅处理一个 token
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -%3E (B, 2D)
xz = self.in_proj(hidden_states.squeeze(1)) # (B, 2D)
# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)
x, z = xz.chunk(2, dim=-1) # (B, D)
# ===================== 输入预处理部分 ↑ =====================
# ================== 卷积部分 + 激活操作 ↓ =====================
if causal_conv1d_update is None: # 如果没有使用加速卷积更新函数
# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = x # 将当前输入 x 放到状态缓存的最后一个位置
# 执行一维卷积,卷积结果为 (B, D)
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
if self.conv1d.bias is not None: # 如果存在卷积偏置
x = x + self.conv1d.bias # 加上偏置
# 激活函数处理并将数据类型转回为原始输入的 dtype
x = self.act(x).to(dtype=dtype)
else: # 如果使用了加速的卷积更新函数
x = causal_conv1d_update(
x, # 输入
conv_state, # 卷积状态
rearrange(self.conv1d.weight, "d 1 w -> d w"), # 卷积权重
self.conv1d.bias, # 卷积偏置
self.activation, # 激活函数 (如 silu 或 swish)
)
# ================== 卷积部分 + 激活操作 ↑ =====================
# ==================== SSM 部分 ↓ ==========================
# 对卷积后的 x 进行线性投影,得到 dt, B 和 C
x_db = self.x_proj(x) # (B, dt_rank + 2 * d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C
# 时间常数 dt 通过线性投影更新,不加偏置
dt = F.linear(dt, self.dt_proj.weight) # (B, d_inner)
# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# ====状态空间模型具体步骤 (SSM Step)====
if selective_state_update is None: # 如果没有使用选择性状态更新函数
# 离散化 A 和 B 矩阵
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) # 对 dt 进行 softplus 激活并加上偏置
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) # 计算离散化后的 A 矩阵
dB = torch.einsum("bd,bn->bdn", dt, B) # 计算离散化后的 B 矩阵
# 更新状态空间模型的状态
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
# 计算输出 y
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) # 状态乘以 C 得到输出
y = y + self.D.to(dtype) * x # 加上 D * x 的跳跃连接
# ======= mamba块右边的激活叠加 ========
y = y * self.act(z) # 与 z 经过激活函数的结果相乘 (B, D)
else: # 如果使用了选择性状态更新函数
y = selective_state_update( # 调用加速的选择性状态更新函数
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
# ===================== SSM 部分 ↑ =========================
# 将输出通过线性投影层
out = self.out_proj(y) # (B, D)
return out.unsqueeze(1), conv_state, ssm_state # 返回输出 (B, 1, D) 以及更新后的状态>)
此外,辅助函数 allocate_inference_cache
和 _get_states_from_cache
用于在推理过程中管理卷积状态和状态空间模型的状态缓存。
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
"""
为推理过程分配卷积和状态空间模型的状态缓存。
batch_size: 批次大小
max_seqlen: 最大序列长度
dtype: 数据类型 (可选)
"""
device = self.out_proj.weight.device # 获取设备 (CPU 或 GPU)
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype # 如果未指定 dtype,则使用卷积权重的 dtype
# 初始化卷积状态,形状为 (B, D, W),其中 W 是卷积窗口大小
conv_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype # 如果未指定 dtype,则使用 dt_proj 权重的 dtype
# 初始化状态空间模型的状态,形状为 (B, D, S),其中 S 是状态维度
ssm_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state # 返回分配的状态
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
"""
从缓存中获取卷积和状态空间模型的状态,若缓存不存在则初始化状态。
inference_params: 推理参数
batch_size: 批次大小
initialize_states: 是否初始化状态
"""
assert self.layer_idx is not None # 确保层索引存在
if self.layer_idx not in inference_params.key_value_memory_dict: # 如果缓存中没有当前层的状态
batch_shape = (batch_size,)
# 初始化卷积状态
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
# 初始化状态空间模型的状态
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
# dtype=torch.float32,
)
# 将初始化的状态存入缓存字典
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else: # 如果缓存中已有状态
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states: # 如果需要初始化状态
conv_state.zero_() # 将卷积状态清零
ssm_state.zero_() # 将状态空间模型的状态清零
return conv_state, ssm_state # 返回卷积和状态空间模型的状态
2.1.2 Block类
用于构建包含 Mamba
模块的更大网络。封装了 LayerNorm/RMSNorm 和残差连接的简单模块,以及Mamba的类。这个 Block 与传统的 Transformer 预归一化块稍有不同。
标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
而这里的顺序是: Add -> LN -> Mixer ,并且返回的是 hidden_states 和 residual。
这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
需要提供残差,除非是第一个块(第一个块不需要残差输入)。
......
class Block(nn.Module):
def __init__(
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
这是一个封装了 LayerNorm/RMSNorm 和残差连接的简单模块,包含了一个 mixer(Mamba) 类。
这个 Block 与传统的 Transformer 预归一化块稍有不同。
标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
而这里的顺序是: Add -> LN -> Mixer,并且返回的是 hidden_states 和 residual。
这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
需要提供残差,除非是第一个块(第一个块不需要残差输入)。
"""
super().__init__() # 调用父类的初始化方法,确保 nn.Module 的属性被正确初始化
self.residual_in_fp32 = residual_in_fp32 # 是否在 fp32 精度下处理残差
self.fused_add_norm = fused_add_norm # 是否启用残差加法和归一化融合
self.mixer = mixer_cls(dim) # 初始化 mixer 模块,该模块用于处理输入数据的混合
self.norm = norm_cls(dim) # 初始化归一化层,默认为 nn.LayerNorm
# 如果启用了融合归一化,需要检查 RMSNorm 是否正确导入,并验证 self.norm 是 LayerNorm 或 RMSNorm
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm 导入失败" # 确保 RMSNorm 导入成功
assert isinstance( # 确保 norm 是 LayerNorm 或 RMSNorm 类型
self.norm, (nn.LayerNorm, RMSNorm)
), "fused_add_norm 模式下仅支持 LayerNorm 和 RMSNorm"
def forward( # 定义前向传播方法
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
):
"""
将输入传递给编码层的前向传播方法。
参数:
hidden_states: 输入序列 (必须提供)。
residual: 输入的残差,如果 residual 为 None,hidden_states 直接作为残差使用。
"""
# 如果不启用 fused_add_norm,则进行标准的 LayerNorm 操作
if not self.fused_add_norm:
# 计算残差: 如果传入了 residual,将其加到 hidden_states 上,否则将 hidden_states 作为 residual
residual = (hidden_states + residual) if residual is not None else hidden_states
# 将 residual 经过 norm 归一化
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
# 如果 residual_in_fp32 为真,将 residual 转换为 float32 处理
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else: # 启用了 fused_add_norm 时,使用融合的归一化和加法操作
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
# 使用 fused_add_norm_fn 处理 hidden_states 和 residual
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True, # 使用预归一化
residual_in_fp32=self.residual_in_fp32, # 在 fp32 中计算残差
eps=self.norm.eps, # 归一化时使用的 epsilon 值,防止除零错误
)
# 将归一化后的 hidden_states 传递给 mixer 模块进行进一步处理
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
# 返回处理后的 hidden_states 和 residual
return hidden_states, residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
# 为推理分配缓存空间,用于存储中间状态,调用 mixer 的同名方法
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
2.2 官方完整实例
mixer_seq_simple.py 是官方的一个完整实例。
一个比较关键的函数是create_block函数,相当于使用mamba的一个接口。
create_block函数中包含了具体应该如何使用mamba_simple.py中的Block类
...
def create_block(
d_model,
ssm_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
device=None,
dtype=None,
):
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
# 创建Mamba混合器类的偏函数,允许传递额外的配置和工厂关键字参数
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
# 根据是否使用RMSNorm,创建归一化层类的偏函数,并设置epsilon和其他工厂关键字参数
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
# 创建Block实例,使用前面定义的混合器和归一化类
block = Block(
d_model,
mixer_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
# 设置层索引,便于在模型中跟踪和配置此块
block.layer_idx = layer_idx
return block
...
(未完待续。。。)