【官方Mamba库】原理简述和代码解析

Mamba原理:参考链接1参考链接2参考链接3
Mamba代码:参考链接mamba-mini实现参考链接

官方mamba库链接: https://github.com/state-spaces/mamba
(注:以下以1.1.4版本mamba-ssm库的代码为例)

🥳🥳🥳欢迎各位大佬一起来交流讨论🥳🥳🥳

1 代码原理简述

熟悉原理的话可以直接看第二章代码

原理介绍摘录参考了链接,(这位大佬的博客应该是全网最详细的了,想摸清原理一定要读一下)

Mamba算法的本质是借鉴了现代控制理论中的状态空间模型SSM

1.1 原始结构——SSM

状态空间模型SSM的公式如下:
请添加图片描述

通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态。(A、B、C、D这4个矩阵是可学习的参数,是可以学习到的,学习好之后,在SSM中,矩阵A、B、C、D便固定不变了,但到后续mamba中则这4个矩阵可以随着输入不同而可变)
这一过程可以由如下图来表示:
请添加图片描述

1.2 结构改进——S4(Structured State Space for Sequences)

S4模型对SSM的改进有以下三点:

  1. 采用零阶保持,来进行连续化:由于SSM模型是针对连续函数的,但是在文本、图像等领域,数据都是离散的,所以我们需要将离散的点连续化,才能输入进SSM模型,最后再从连续的输出中采样离散的点来得到真正的输出
  2. 使用循环+卷积结构表示,从而能够并行训练,加快训练速度
  3. 使用HIPPO矩阵,以处理远程依赖

1.2.1 离散化

由于以上公式是针对连续变量的,单实际计算机相关的应用中往往离散的,所以经过采样零阶保持的方法,便有了以下离散化公式(这个公式很重要,在代码中会体现):
请添加图片描述

然后公式变成了这样:
请添加图片描述
在每个时间步,都会涉及到隐藏状态的更新(比如 h k h_{k} hk取决于 B ‾ x k \overline{\mathrm{B}}\mathrm{x}_{k} Bxk A ‾ h k − 1 \overline{\mathrm{A}}\mathrm{h}_{k-1} Ahk1的共同作用结果,然后通过 C h k Ch_k Chk预测输出 y k y_k yk),这样反复套娃,就得到了类似RNN的循环结构:
请添加图片描述
换个图表示更明显:
在这里插入图片描述

用公式来表示就是这样的:
y 2 = C h 2 = C ( A ˉ h 1 + B ˉ x 2 ) = C ( A ˉ ( A ˉ h 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ( A ˉ ⋅ B ˉ x 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ⋅ A ˉ ⋅ B ˉ x 0 + A ˉ ⋅ B ˉ x 1 + B ˉ x 2 ) = C ⋅ A ˉ 2 ⋅ B ˉ x 0 + C ⋅ A ˉ ⋅ B ˉ ⋅ x 1 + C ⋅ B ˉ x 2 \begin{aligned} y_2& =Ch_{2} \\ &=C\left(\bar{A}h_{1}+\bar{B}x_{2}\right) \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}h_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}\cdot\bar{B}x_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\cdot\bar{A}\cdot\bar{B}x_{0}+\bar{A}\cdot\bar{B}x_{1}+\bar{B}x_{2}\end{pmatrix} \\ &=C\cdot\bar{A}^{2}\cdot\bar{B}x_{0}+C\cdot\bar{A}\cdot\bar{B}\cdot x_{1}+C\cdot\bar{B}x_{2} \end{aligned} y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(AˉBˉx0+Bˉx1)+Bˉx2)=C(AˉAˉBˉx0+AˉBˉx1+Bˉx2)=CAˉ2Bˉx0+CAˉBˉx1+CBˉx2
由此类推:
y 3 = C A A A B ‾ x 0 + C A A B ‾ x 1 + C A B ‾ x 2 + C B ‾ x 3 y_{3}=\mathbf{C\overline{AAAB}}x_{0}+\mathbf{C\overline{AAB}}x_{1}+\mathbf{C\overline{AB}}x_{2}+\mathbf{C\overline{B}}x_{3} y3=CAAABx0+CAABx1+CABx2+CBx3
y k = C A ˉ k B ˉ x 0 + C A ˉ k − 1 B ˉ x 1 + ⋯ + C A ˉ B ˉ x k − 1 + C B ˉ x k y_{k}=C\bar{A}^{k}\bar{B}x_{0}+C\bar{A}^{k-1}\bar{B}x_{1}+\cdots+C\bar{A}\bar{B}x_{k-1}+C\bar{B}x_{k} yk=CAˉkBˉx0+CAˉk1Bˉx1++CAˉBˉxk1+CBˉxk
这样看来循环结构就更明显了。

1.2.2HiPPO

HiPPO是一种解决长距离依赖问题的方法,用于构建并调整状态矩阵A使其逼近最优解,这种方法比把A初始化为随机矩阵要好得多
请添加图片描述

1.3 最终版本——Mamba(又称S6或selective SSMs)

Mamba模型对于S4模型的改进有以下三点:

  1. 参数化矩阵:对输入信息进行有选择性的处理,从而得到类似Attention的效果,即不同的输入拥有不同的状态,token信息
  2. 硬件感知算法,并行化–选择扫描算法,加快训练推理速度
  3. 更简化的SSM模型架构

在最终的Mamaba中,作者让B矩阵、C矩阵、 Δ \Delta Δ 成为输入的函数,让模型能够根据输入内容自适应地调整其行为
请添加图片描述

最终的整体流程:
请添加图片描述
结构细节如下:
请添加图片描述

其中的“选择性SSM(即Selective SSM)”具有以下属性:

  1. Recurrent SSM通过离散化创建循环SSM
  2. HiPPO对矩阵A进行初始化A以捕获长程依赖性
  3. 选择性扫描算法(Selective scan algorithm)选择性压缩信息
  4. 硬件感知算法(Hardware-aware algorithm)加速计算

Mamba和Trasformer,RNN相对的优势:
请添加图片描述

2 代码库目录结构

mamba
├── benchmarks
│ 	└── benchmark_generation_mamba_simple.py  // 示例模型的推理脚本
├── csrc
│ 	└── selective_scan  // 选择性扫描的c++实现
├── evals
│ 	└── lm_harness_eval.py
├── mamba_ssm
│ 	├── models
│   │   ├── config_mamba.py
│   │   └── mixer_seq_simple.py  // 使用mamba构建的一个完整的语言模型示例
│ 	├── modules
│   │   └── mamba_simple.py   // mamba block的实现
│ 	├── ops
│   │   ├── triton
│   │   │   ├── layernorm.py
│   │   │   ├── selective_state_update.py
│   │   └── selective_scan_interface.py   // 选择性SSM层的实现
│ 	├── utils
│   │   ├── generation.py
│   │   └── hf.py
└── test
		└── ops
		    ├── triton
		    │		├── test_selective_state_update.py
        └──test_selective_scan.py


2.1 mamba_simple.py主体结构

2.1.1 Mamba类

class Mamba(nn.Module):
	def __init__(
	...
	def forward(self, hidden_states, inference_params=None):
	...
	def step(self, hidden_states, conv_state, ssm_state):
	...
	def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
	...
	def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
	...

(1)__init__函数

Mamba类/__init__函数负责各个模块的初始化

  • 输入投影初始化:输入特征通过线性投影被扩展为两倍的内部特征维度,准备进行卷积和状态空间模型的计算。
  • 卷积层:使用 1D 卷积操作,捕捉局部的时间特征,使用 d_conv 作为卷积核大小。
  • 状态空间模型:通过 A_log 矩阵初始化状态矩阵,并通过跳跃连接 D 来捕捉长期依赖。
  • 时间常数初始化:时间常数 dt 通过 dt_proj 层生成,并在初始化时通过 softplus 函数的逆函数来确保其在合理范围内。
  • 输出层:通过线性投影,将扩展后的特征维度恢复到输入特征维度,完成一步处理后的输出。
def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model  # 输入特征维度
        self.d_state = d_state  # 状态空间模型的维度
        self.d_conv = d_conv    # 卷积核大小
        self.expand = expand    # 特征扩展倍数
        self.d_inner = int(self.expand * self.d_model)  # 扩展后的内部维度
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank  # 自动确定时间常数的秩
        self.use_fast_path = use_fast_path  # 是否使用优化的快速路径
        self.layer_idx = layer_idx  # 当前层的索引

        # 线性投影层,将输入特征扩展为内部维度的两倍
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        # 1D 卷积操作,卷积核大小为 `d_conv`,输入和输出通道数均为 `d_inner`
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        # 激活函数 SiLU (Sigmoid Linear Unit),一种常用的非线性激活
        self.activation = "silu"
        self.act = nn.SiLU()

        # 再次线性投影,用于生成时间常数、B 和 C 矩阵
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        # 时间常数的线性投影层
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # 初始化时间常数的投影层参数,确保初始化时方差保持稳定
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # 初始化 dt 投影的偏置,使得F.softplus(dt_bias)后的值在 `dt_min` 和 `dt_max` 之间
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        
        # 使用 softplus 的逆函数来计算偏置
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # 我们的初始化会将所有的Linear.bias变成0,所以这里标记Linear.bias不需要重新初始化
        self.dt_proj.bias._no_reinit = True

        # S4D 初始化过程,状态矩阵 A 的对数
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # 对 A 取对数,保持精度fp32
        self.A_log_temp = GwParameters(A_log)
        self.A_log = self.A_log_temp.param
        self.A_log._no_weight_decay = True # 防止权重衰减

        # 跳跃连接参数 D
        self.D_temp = GwParameters(torch.ones(self.d_inner, device=device))
        self.D = self.D_temp.param
        self.D._no_weight_decay = True

        # 输出层的线性投影,将内部维度投影回原始特征维度
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
(2)forward函数

Mamba类/ forward 函数负责在前向传播过程中对输入的序列数据进行处理。它主要包括以下几步:

  1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
  2. 状态空间模型 (SSM):通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
  3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
  4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
  5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
[def forward(self, hidden_states, inference_params=None):
        """
        主要思路:
        1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
        2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
        3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
        4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
        5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。

        hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 D
        Returns: 返回与输入相同形状的输出
        """
        batch, seqlen, dim = hidden_states.shape

        # 初始化卷积和状态空间模型的状态为 None
        conv_state, ssm_state = None, None

        if inference_params is not None:# 如果推理参数不为空
        # 从缓存中获取卷积和状态空间模型的状态
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # 调用 step 函数逐步更新状态,并返回输出
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBL
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:  # 如果存在偏置
        # 将偏置加到投影结果上
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # 计算 A 矩阵 (d_inner, d_state)

        # 如果使用快速路径 (不使用推理状态且启用了快速内核)
        # 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作
        if self.use_fast_path and inference_params is None:  # 不支持输出状态
            # 使用快速内核进行计算
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else: # 如果不使用快速路径 (常规路径)
            # 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))
            x, z = xz.chunk(2, dim=1)  # x 和 z 分别为输入投影的两部分

            # 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)
            if conv_state is not None:
                # 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # 更新卷积状态 (B D W)

            # 执行一维卷积并进行激活操作 (如果存在加速的卷积函数,则使用加速路径)
            if causal_conv1d_fn is None:  # 如果没有使用加速卷积函数
                x = self.act(self.conv1d(x)[..., :seqlen])  # 常规卷积操作并激活
            else:  # 使用加速卷积函数
                x = causal_conv1d_fn(
                    x=x,                                    # 输入
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 调整卷积权重
                    bias=self.conv1d.bias,                  # 卷积偏置
                    activation=self.activation,             # 激活函数 (silu 或 swish)
                )


            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            # 对卷积结果进行投影以生成 dt, B 和 C
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # 线性投影 (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C
            dt = self.dt_proj.weight @ dt.t()  # 对时间常数 dt 进行线性变换
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)  # 重塑为 (b, d, l)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 B
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 C

            
            assert self.activation in ["silu", "swish"]

            # 使用 selective_scan_fn 进行状态空间模型计算
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )

            # 如果存在状态空间模型的状态缓存
            if ssm_state is not None:
                y, last_state = y  # 更新最后的状态
                ssm_state.copy_(last_state)  # 将更新后的状态缓存起来
            
            # 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)
            y = rearrange(y, "b d l -> b l d")
            # 通过输出层的线性投影恢复原始特征维度
            out = self.out_proj(y)
        return out](<def forward(self, hidden_states, inference_params=None):
        """
        主要思路:
        1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
        2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
        3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
        4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
        5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。

        hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 D
        Returns: 返回与输入相同形状的输出
        """
        batch, seqlen, dim = hidden_states.shape

        # 初始化卷积和状态空间模型的状态为 None
        conv_state, ssm_state = None, None


# TODO: inference_params这个参数有什么用?什么时候会被调整?
# ============================================== inference_params is True ==============================================
        if inference_params is not None:# 如果推理参数不为空
        # 从缓存中获取卷积和状态空间模型的状态
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset %3E 0:
                # 调用 step 函数逐步更新状态,并返回输出
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out
# ======================================================================================================================



        # 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBL
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:  # 如果存在偏置
        # 将偏置加到投影结果上
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") # 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作

        A = -torch.exp(self.A_log.float())  # 计算 A 矩阵 (d_inner, d_state)



# ============================================== inference_params is None ==============================================
        # 如果使用快速路径,整体的思路和非快速一致,可以看else中的解释
        if self.use_fast_path and inference_params is None:  # (不使用推理状态且启用了快速内核)不支持输出状态
            # 使用快速内核进行计算
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else: # 如果不使用快速路径 (常规路径)
            # 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))
            x, z = xz.chunk(2, dim=1)  # x 和 z 分别为输入投影的两部分

            # 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)
            if conv_state is not None:
                # 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # 更新卷积状态 (B D W)

        # ===================== Conv部分+激活操作 (如果存在加速的卷积函数,则使用加速路径) ↓=======================
            if causal_conv1d_fn is None:  # 如果没有使用加速卷积函数
                x = self.act(self.conv1d(x)[..., :seqlen])  # 常规卷积操作并激活
            else:  # 使用加速卷积函数
                x = causal_conv1d_fn(
                    x=x,                                    # 输入
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 调整卷积权重
                    bias=self.conv1d.bias,                  # 卷积偏置
                    activation=self.activation,             # 激活函数 (silu 或 swish)
                )
        # ====================================== Conv部分+激活操作 ↑========================================


        # ======================================== SSM 部分 ↓ ========================================
            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.

            # ==对卷积结果进行投影以生成 dt, B 和 C==
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # 线性投影 (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C
            dt = self.dt_proj.weight @ dt.t()  # 对时间常数 dt 进行线性变换
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)  # 重塑为 (b, d, l)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 B
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 C
            #====================================

            
            assert self.activation in ["silu", "swish"]

            # 使用 selective_scan_fn 进行状态空间模型计算
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )

            # 如果存在状态空间模型的状态缓存
            if ssm_state is not None:
                y, last_state = y  # 更新最后的状态
                ssm_state.copy_(last_state)  # 将更新后的状态缓存起来
        # ==================================== SSM 部分 ↑ =============================================


            # 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)
            y = rearrange(y, "b d l -> b l d")
            # 通过输出层的线性投影恢复原始特征维度
            out = self.out_proj(y)
        return out>)
(3)step函数

Mamba类/ step 函数, 是一次完整的mamba块处理过程(inference_params不为空时执行)。负责在解码过程中处理单个时间步长的数据,并更新卷积状态和状态空间模型(SSM)的状态。主要逻辑分为两部分:

  1. 卷积步骤 (Conv Step):对输入进行卷积操作,用来捕捉局部的时序特征。
  2. 状态空间模型步骤 (SSM Step):利用状态空间模型进行时间序列的动态建模,捕捉序列中的长期依赖关系。
[def step(self, hidden_states, conv_state, ssm_state):
        """
        执行过程中单步处理
        hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度
        conv_state: 卷积的状态缓存 (B, D, W)
        ssm_state: 状态空间模型的状态缓存 (B, D, S)
        Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state
        """
        # 确保输入的第二维是1,即仅处理一个 token
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        # 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -> (B, 2D)
        xz = self.in_proj(hidden_states.squeeze(1))  # (B, 2D)
        
        # 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)
        x, z = xz.chunk(2, dim=-1)  # (B, D)


        # 卷积操作
        if causal_conv1d_update is None:  # 如果没有使用加速卷积更新函数
            # 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x  # 将当前输入 x 放到状态缓存的最后一个位置

            # 执行一维卷积,卷积结果为 (B, D)
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
            if self.conv1d.bias is not None:  # 如果存在卷积偏置
                x = x + self.conv1d.bias  # 加上偏置
            # 激活函数处理并将数据类型转回为原始输入的 dtype
            x = self.act(x).to(dtype=dtype)
        else:  # 如果使用了加速的卷积更新函数
            x = causal_conv1d_update(
                x,                     # 输入
                conv_state,            # 卷积状态
                rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 卷积权重
                self.conv1d.bias,      # 卷积偏置
                self.activation,       # 激活函数 (如 silu 或 swish)
            )

        # 对卷积后的 x 进行线性投影,得到 dt, B 和 C
        x_db = self.x_proj(x)  # (B, dt_rank + 2 * d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C

        # 时间常数 dt 通过线性投影更新,不加偏置
        dt = F.linear(dt, self.dt_proj.weight)  # (B, d_inner)

        # 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # 状态空间模型步骤 (SSM Step)
        if selective_state_update is None:  # 如果没有使用选择性状态更新函数
            # 离散化 A 和 B 矩阵
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))  # 对 dt 进行 softplus 激活并加上偏置
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))  # 计算离散化后的 A 矩阵
            dB = torch.einsum("bd,bn->bdn", dt, B)  # 计算离散化后的 B 矩阵
            # 更新状态空间模型的状态
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            # 计算输出 y
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)  # 状态乘以 C 得到输出
            y = y + self.D.to(dtype) * x  # 加上 D * x 的跳跃连接
            y = y * self.act(z)  # 与 z 经过激活函数的结果相乘 (B, D)
        else:  # 如果使用了选择性状态更新函数
            y = selective_state_update(  # 调用加速的选择性状态更新函数
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

         # 将输出通过线性投影层
        out = self.out_proj(y)  # (B, D)
        return out.unsqueeze(1), conv_state, ssm_state  # 返回输出 (B, 1, D) 以及更新后的状态](<def step(self, hidden_states, conv_state, ssm_state):
        """
        执行过程中单步处理
        hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度
        conv_state: 卷积的状态缓存 (B, D, W)
        ssm_state: 状态空间模型的状态缓存 (B, D, S)
        Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state
        """
        # ===================== 输入预处理部分 ↓ =====================
        # 确保输入的第二维是1,即仅处理一个 token
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        # 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -%3E (B, 2D)
        xz = self.in_proj(hidden_states.squeeze(1))  # (B, 2D)
        
        # 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)
        x, z = xz.chunk(2, dim=-1)  # (B, D)
        # ===================== 输入预处理部分 ↑ =====================

        # ================== 卷积部分 + 激活操作 ↓ =====================
        if causal_conv1d_update is None:  # 如果没有使用加速卷积更新函数
            # 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x  # 将当前输入 x 放到状态缓存的最后一个位置

            # 执行一维卷积,卷积结果为 (B, D)
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
            if self.conv1d.bias is not None:  # 如果存在卷积偏置
                x = x + self.conv1d.bias  # 加上偏置
            # 激活函数处理并将数据类型转回为原始输入的 dtype
            x = self.act(x).to(dtype=dtype)
        else:  # 如果使用了加速的卷积更新函数
            x = causal_conv1d_update(
                x,                     # 输入
                conv_state,            # 卷积状态
                rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 卷积权重
                self.conv1d.bias,      # 卷积偏置
                self.activation,       # 激活函数 (如 silu 或 swish)
            )
        # ================== 卷积部分 + 激活操作 ↑ =====================


        # ====================   SSM 部分 ↓ ==========================
        # 对卷积后的 x 进行线性投影,得到 dt, B 和 C
        x_db = self.x_proj(x)  # (B, dt_rank + 2 * d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C

        # 时间常数 dt 通过线性投影更新,不加偏置
        dt = F.linear(dt, self.dt_proj.weight)  # (B, d_inner)

        # 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # ====状态空间模型具体步骤 (SSM Step)====
        if selective_state_update is None:  # 如果没有使用选择性状态更新函数
            # 离散化 A 和 B 矩阵
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))  # 对 dt 进行 softplus 激活并加上偏置
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))  # 计算离散化后的 A 矩阵
            dB = torch.einsum("bd,bn->bdn", dt, B)  # 计算离散化后的 B 矩阵
            # 更新状态空间模型的状态
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            # 计算输出 y
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)  # 状态乘以 C 得到输出
            y = y + self.D.to(dtype) * x  # 加上 D * x 的跳跃连接

            # ======= mamba块右边的激活叠加 ========
            y = y * self.act(z)  # 与 z 经过激活函数的结果相乘 (B, D)

        else:  # 如果使用了选择性状态更新函数
            y = selective_state_update(  # 调用加速的选择性状态更新函数
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )
        # =====================  SSM 部分 ↑ =========================

         # 将输出通过线性投影层
        out = self.out_proj(y)  # (B, D)
        return out.unsqueeze(1), conv_state, ssm_state  # 返回输出 (B, 1, D) 以及更新后的状态>)

此外,辅助函数 allocate_inference_cache_get_states_from_cache 用于在推理过程中管理卷积状态和状态空间模型的状态缓存。

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
	"""
	为推理过程分配卷积和状态空间模型的状态缓存。
	batch_size: 批次大小
	max_seqlen: 最大序列长度
	dtype: 数据类型 (可选)
	"""
	device = self.out_proj.weight.device  # 获取设备 (CPU 或 GPU)
	conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype  # 如果未指定 dtype,则使用卷积权重的 dtype
	# 初始化卷积状态,形状为 (B, D, W),其中 W 是卷积窗口大小
	conv_state = torch.zeros(
		batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
	)
	ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype  # 如果未指定 dtype,则使用 dt_proj 权重的 dtype
	# 初始化状态空间模型的状态,形状为 (B, D, S),其中 S 是状态维度
	ssm_state = torch.zeros(
		batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
	)
	return conv_state, ssm_state  # 返回分配的状态

def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
	"""
	从缓存中获取卷积和状态空间模型的状态,若缓存不存在则初始化状态。
	inference_params: 推理参数
	batch_size: 批次大小
	initialize_states: 是否初始化状态
	"""
	assert self.layer_idx is not None  # 确保层索引存在
	if self.layer_idx not in inference_params.key_value_memory_dict:  # 如果缓存中没有当前层的状态
		batch_shape = (batch_size,)
		# 初始化卷积状态
		conv_state = torch.zeros(
			batch_size,
			self.d_model * self.expand,
			self.d_conv,
			device=self.conv1d.weight.device,
			dtype=self.conv1d.weight.dtype,
		)
		# 初始化状态空间模型的状态
		ssm_state = torch.zeros(
			batch_size,
			self.d_model * self.expand,
			self.d_state,
			device=self.dt_proj.weight.device,
			dtype=self.dt_proj.weight.dtype,
			# dtype=torch.float32,
		)
		# 将初始化的状态存入缓存字典
		inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
	else:  # 如果缓存中已有状态
		conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
		# TODO: What if batch size changes between generation, and we reuse the same states?
		if initialize_states:  # 如果需要初始化状态
			conv_state.zero_()  # 将卷积状态清零
			ssm_state.zero_()  # 将状态空间模型的状态清零
	return conv_state, ssm_state  # 返回卷积和状态空间模型的状态

2.1.2 Block类

用于构建包含 Mamba 模块的更大网络。封装了 LayerNorm/RMSNorm 和残差连接的简单模块,以及Mamba的类。这个 Block 与传统的 Transformer 预归一化块稍有不同。
标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
而这里的顺序是: Add -> LN -> Mixer ,并且返回的是 hidden_states 和 residual。
这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
需要提供残差,除非是第一个块(第一个块不需要残差输入)。

......

class Block(nn.Module): 
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        这是一个封装了 LayerNorm/RMSNorm 和残差连接的简单模块,包含了一个 mixer(Mamba) 类。
        
        这个 Block 与传统的 Transformer 预归一化块稍有不同。
        标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
        而这里的顺序是: Add -> LN -> Mixer,并且返回的是 hidden_states 和 residual。

        这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
        需要提供残差,除非是第一个块(第一个块不需要残差输入)。
        """
        super().__init__()  # 调用父类的初始化方法,确保 nn.Module 的属性被正确初始化
        self.residual_in_fp32 = residual_in_fp32  # 是否在 fp32 精度下处理残差
        self.fused_add_norm = fused_add_norm  # 是否启用残差加法和归一化融合
        self.mixer = mixer_cls(dim)  # 初始化 mixer 模块,该模块用于处理输入数据的混合
        self.norm = norm_cls(dim)  # 初始化归一化层,默认为 nn.LayerNorm

        # 如果启用了融合归一化,需要检查 RMSNorm 是否正确导入,并验证 self.norm 是 LayerNorm 或 RMSNorm
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm 导入失败"  # 确保 RMSNorm 导入成功
            assert isinstance(  # 确保 norm 是 LayerNorm 或 RMSNorm 类型
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "fused_add_norm 模式下仅支持 LayerNorm 和 RMSNorm"

    def forward(  # 定义前向传播方法
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        """
        将输入传递给编码层的前向传播方法。

        参数:
            hidden_states: 输入序列 (必须提供)。
            residual: 输入的残差,如果 residual 为 None,hidden_states 直接作为残差使用。
        """
        # 如果不启用 fused_add_norm,则进行标准的 LayerNorm 操作
        if not self.fused_add_norm:
            # 计算残差: 如果传入了 residual,将其加到 hidden_states 上,否则将 hidden_states 作为 residual
            residual = (hidden_states + residual) if residual is not None else hidden_states

            # 将 residual 经过 norm 归一化
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))

            # 如果 residual_in_fp32 为真,将 residual 转换为 float32 处理
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:  # 启用了 fused_add_norm 时,使用融合的归一化和加法操作
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # 使用 fused_add_norm_fn 处理 hidden_states 和 residual
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,  # 使用预归一化
                residual_in_fp32=self.residual_in_fp32,  # 在 fp32 中计算残差
                eps=self.norm.eps,  # 归一化时使用的 epsilon 值,防止除零错误
            )

        # 将归一化后的 hidden_states 传递给 mixer 模块进行进一步处理
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)

        # 返回处理后的 hidden_states 和 residual
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # 为推理分配缓存空间,用于存储中间状态,调用 mixer 的同名方法
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

2.2 官方完整实例

mixer_seq_simple.py 是官方的一个完整实例。
一个比较关键的函数是create_block函数,相当于使用mamba的一个接口。
create_block函数中包含了具体应该如何使用mamba_simple.py中的Block类

...
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    # 创建Mamba混合器类的偏函数,允许传递额外的配置和工厂关键字参数
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    # 根据是否使用RMSNorm,创建归一化层类的偏函数,并设置epsilon和其他工厂关键字参数
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    # 创建Block实例,使用前面定义的混合器和归一化类
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    # 设置层索引,便于在模型中跟踪和配置此块
    block.layer_idx = layer_idx
    return block
...

(未完待续。。。)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值