MiniMax-01中Lightning Attention的由来(线性注意力进化史)

部署运行你感兴趣的模型镜像

引言

MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。

那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?

本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。

原始注意力

现在主流的有两类模型,一种是应用双向注意力的bert类模型,另一种是应用单向注意力的gpt类模型,他们所使用的注意力其实是有细微差别的。

  • 双向注意力(bert类),就是传统认知中标准的注意力

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKT)V

  • 单向注意力(因果模型,gpt类),只能看到当前和前面的token,所有要在softmax之前乘上一个掩码矩阵,M为单向掩码矩阵

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKTM)V

其中Q、K、V每个矩阵的维度都是[n, d],即[序列长度,隐层维度],此时 Q K T QK^T QKT的维度是[n, n],所以整体复杂度是 O ( n 2 d ) O(n^2d) O(n2d)。其中d是固定大小, n 2 n^2 n2随着序列长度平方增加,就主导了整体的复杂度。

线性注意力

原始出处:Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

注意力计算可抽象成如下形式,sim表示可以使用任何的相似度函数,不一定非要softmax,可以是类似于多项式注意力或RBF核注意力等

V i ′ = ∑ j = 1 N sim ⁡ ( Q i , K j ) V j ∑ j = 1 N sim ⁡ ( Q i , K j ) . V_i^{^{\prime}}=\frac{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)V_j}{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)}. Vi=j=1Nsim(Qi,Kj)j=1Nsim(Qi,Kj)Vj.

但相似度函数也是有要求的,需要施加唯一约束sim(·)非负。基于这个条件,给定这样一个内核 ϕ ( x ) \phi(x) ϕ(x),可以将上式重写为

V i = ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V j ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V_i = \frac{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j)} Vi=j=1Nϕ(Qi)Tϕ(Kj)j=1Nϕ(Qi)Tϕ(Kj)Vj

利用矩阵乘法结合律得到

V i = ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V j T ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V_i = \frac{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j) V_j^T}{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j)} Vi=ϕ(Qi)Tj=1Nϕ(Kj)ϕ(Qi)Tj=1Nϕ(Kj)VjT

为简化理解可写成如下形式

Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

深层理解:每个时间步的K和V可以提前计算并作为单个向量存储下来,在推理生成时直接用Q乘以每个时间步的KV,简单情况下KV cache的缓存向量都变少了,可以像RNN一样每次预测的时间都几乎是恒定的

注意:此时使用的一般都是绝对位置编码,Q、K矩阵没有乘以参数矩阵

此时 ϕ ( K ) T V \phi(K)^TV ϕ(K)TV的复杂度是 O ( d 2 ) O(d^2) O(d2),所以整体复杂度变成了 O ( n d 2 ) O(nd^2) O(nd2),随着序列长度n线性增长,此时就是线性注意力了。

(可选):通常线性注意力的公式还有如下形式

O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ1(QKTV)

(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。

因果模型存在的问题

注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的 q t q_t qt乘以前面的 k j k_j kj,再乘以 v j v_j vj累加求和。此时 Q K T QK^T QKT可以正常进行矩阵运算,然后使用 ⊙ \odot (Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。

O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKTM)V

o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1t(qtTkj)vj

此时注意,上面公式的运算涉及 ⊙ \odot ,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KTMV) ⊙ \odot 是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。

累加求和操作的限制

双向注意力模型(bert)中使用的线性注意力如下,可以先算KV

( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图

在这里插入图片描述

  • 双向注意力计算
    K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]

K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12k1dk21k22k2dk31k32k3dk41k42k4d

V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d

K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]=[[KTV]1[KTV]2...[KTV]d]

此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。

q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 . . . q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} & q_3[K^{T}V]_{2} & ... & q_3[K^{T}V]_{d} \end{bmatrix} q3KTV=q3[[KTV]1[KTV]2...[KTV]d]=[q3[KTV]1q3[KTV]2...q3[KTV]d]

也可以使用以下代码测试

import torch
q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)

# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1], 
                   [2, 2, 2, 2], 
                   [3, 3, 3, 3], 
                   [4, 4, 4, 4],
                   [5, 5, 5, 5],
                   [6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1]])

print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q3, kT @ v)
print('result', result)

在这里插入图片描述

此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。

但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。

o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。

Lightning Attention

到这里可以引出MiniMax-01 中所使用的Lightning Attention了,但其实这个注意力有两个版本,MiniMax-01中所提到的就是是Lightning Attention-2,那咱们先看看第一个版本做了什么。

Lightning Attention-1

源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer

Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ⊙ M ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)TM)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention中的硬件加速。

其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是 O ( n 2 d ) O(n^2d) O(n2d)
在这里插入图片描述
在这里插入图片描述

Lightning Attention-2

源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models

Lightning Attention-2解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为 O ( n d 2 ) O(nd^2) O(nd2)
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。

公式中也可以发现,当前step之前的k和v是可以相乘的,比如 q 3 q_3 q3在计算时,可以将 k 1 T v 1 + k 2 T v 2 + k 3 T v 3 k_1^Tv_1+k_2^Tv_2+k_3^Tv_3 k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
在这里插入图片描述
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。

此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。

块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。

下图是Lightning Attention-2的结构图, λ \lambda λ是它的模型所使用的位置编码,忽略即可。
在这里插入图片描述
以下是前向传播和反向传播流程。
在这里插入图片描述
在这里插入图片描述
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了

解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。

备注

个人理解,若有不对请指出,谢谢。

您可能感兴趣的与本文相关的镜像

Yolo-v5

Yolo-v5

Yolo

YOLO(You Only Look Once)是一种流行的物体检测和图像分割模型,由华盛顿大学的Joseph Redmon 和Ali Farhadi 开发。 YOLO 于2015 年推出,因其高速和高精度而广受欢迎

### MiniMax-Text-01 开源代码解析与实现原理 MiniMax-Text-01 是 2025 年初由 MiniMax 发布的基础语言大模型,其核心创新在于采用了线性注意力架构,并实现了高达 400 万 token 的超长上下文能力。这一设计不仅提升了模型在处理长文本时的效率,还使其在性能上接近 GPT-4o,为 Agent 时代的到来奠定了技术基础[^1]。 #### 架构设计:线性注意力机制 MiniMax-Text-01 的核心架构基于线性注意力机制(Linear Attention),这种机制通过将传统的点积注意力转化为线性变换的形式,显著降低了计算复杂度。传统 Transformer 模型中,注意力机制的时间复杂度为 O(),而线性注意力将其优化为 O(n),从而支持更长的上下文长度和更高的推理效率[^3]。 以下是一个简化的线性注意力机制实现示例: ```python import torch import torch.nn as nn class LinearAttention(nn.Module): def __init__(self, dim, heads=8): super().__init__() self.heads = heads self.to_qkv = nn.Linear(dim, dim * 3, bias=False) self.to_out = nn.Linear(dim, dim) def forward(self, x): b, n, d = x.shape qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: t.reshape(b, n, self.heads, -1).transpose(1, 2), qkv) # 线性注意力计算 k = k.softmax(dim=-1) context = torch.einsum('bhnd,bhne->bhde', k, v) out = torch.einsum('bhde,bhnd->bhne', context, q) out = out.transpose(1, 2).reshape(b, n, -1) return self.to_out(out) ``` 该模块通过 `einsum` 实现了高效的张量运算,避免了传统注意力中的高维矩阵乘法操作,从而降低了内存消耗和计算延迟。 #### 训练策略:预训练与监督微调(SFT) MiniMax-Text-01 在训练过程中采用了多阶段的预训练和监督微调(Supervised Fine-Tuning)策略。预训练阶段主要依赖大规模语料库进行自监督学习,使模型具备广泛的通用语言理解能力。而在 SFT 阶段,模型进一步在特定任务数据集上进行优化,以提升其在实际应用中的表现[^3]。 此外,MiniMax-M1 系列模型中还引入了一种创新的强化学习算法 CISPO(Constrained Iterative Self-Play Optimization),用于提升模型在复杂推理任务中的表现。虽然该算法主要用于 MiniMax-M1,但其设计理念也对 MiniMax-Text-01 的训练策略产生了影响,尤其是在逻辑推理和上下文连贯性方面[^3]。 #### 上下文扩展:支持 400 万 token MiniMax-Text-01 最显著的特点之一是支持高达 400 万 token 的上下文长度。这得益于其架构优化和内存管理策略的改进。模型内部采用分块缓存机制(Chunked Cache),将历史 token 分批次存储并按需加载,从而避免一次性加载全部上下文带来的内存瓶颈。这种机制使得模型在处理长文档、对话系统或多轮推理任务时表现出色。 #### 开源与部署 MiniMax-Text-01 的开源版本可通过 HuggingFace 获取,用户可以直接下载并部署在本地环境中。HuggingFace 提供了完整的模型权重、Tokenizer 和推理脚本,方便开发者快速集成到自己的项目中。例如,使用 Transformers 库加载 MiniMax-Text-01 的方式如下: ```python from transformers import AutoModelForCausalLM, AutoTokenizer model_name = "MiniMaxAI/MiniMax-Text-01" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) input_text = "请解释量子力学的基本原理" inputs = tokenizer(input_text, return_tensors="pt") outputs = model.generate(**inputs, max_new_tokens=100) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` 上述代码展示了如何加载模型并进行简单的文本生成任务。由于其强大的上下文处理能力,MiniMax-Text-01 特别适用于需要长文本理解和生成的场景,如法律文书分析、学术论文撰写、自动化报告生成等[^1]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值