深入解析 DeepSeek 最新论文提出的 NSA 注意力机制
引言
Transformer 架构自从 2017 年问世以来,在自然语言处理和计算机视觉等领域取得了革命性进展。然而,标准Transformer的自注意力机制(self-attention)在序列长度为 n 时计算和内存开销为 O ( n 2 ) O(n^2) O(n2),这使得在处理长序列时非常耗费资源 ([2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness)。为了解决长序列建模的效率问题,近年来众多研究提出了近似或稀疏的注意力机制,例如 Longformer、Performer、Reformer、BigBird 等等,它们通过预定义模式、低秩近似或随机映射等方式降低计算复杂度 (Understanding BigBird’s Block Sparse Attention) 。然而,大多数此类方法使用预先设定的稀疏模式或启发式规则,这难以充分适应注意力在不同输入中的动态分布 ([2410.13276] SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs)。这导致在追求效率的同时,模型性能往往有所损失,或者需要复杂的辅助训练技巧。
近期,研究团队 DeepSeek 发布了最新论文,提出了一种创新的注意力机制 NSA。NSA 全称为“原生稀疏注意力”(Native Sparse Attention),它强调从训练过程开始就原生地学习稀疏模式,而非事后裁剪或固定模式 (DeepSeek releases the paper “Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention Mechanism” - PANews)。该方法在保证模型性能接近甚至优于全量注意力的同时,大幅提升了长序列处理的效率 。NSA 的核心思想是结合算法创新与硬件优化,通过层次化的稀疏机制在保持全局语义和局部细节的同时减少不必要的计算开销 。根据 DeepSeek 团队的介绍,NSA 具有以下三大创新特点 :
- 动态层次化稀疏策略:采用粗粒度的Token压缩(coarse-grained token compression)捕获全局上下文,以及细粒度的Token选择(fine-grained token selection)保留局部重要信息,从而既保证全局语义不丢失,又兼顾关键细节。
- 面向硬件的平衡设计:在算法设计上平衡计算与内存访存(提高算术强度),并利用现代硬件(如 GPU Tensor Core)的特性进行优化,大幅加速长序列注意力的计算。
- 端到端可训练的稀疏模式:支持注意力稀疏模式的端到端训练,无需预训练完再裁剪,从预训练阶段模型就学习最优的稀疏结构,减少预训练开销且不牺牲模型性能。
实验结果显示,基于 NSA 预训练的模型在通用基准、长文本任务和推理任务上都达到或超过了全注意力模型的表现。同时在最长可达 64k 的序列上,NSA 在解码、前向传播和反向传播各阶段都实现了显著的加速 。本文将对 NSA 机制的数学原理、与传统Transformer注意力的对比、在不同任务中的应用表现、实现细节以及其优势与局限进行深入解析,力求为专业读者提供全面而透彻的理解。
接下来,我们将首先回顾Transformer自注意力机制的基本公式,然后引出NSA的数学定义和原理;之后比较NSA与常用注意力机制(如自注意力、FlashAttention、Linformer等)的异同;然后讨论NSA在NLP、CV等任务中的实验表现;接着介绍NSA的实现细节,包括伪代码示例和性能优化方法;最后总结NSA的优势、可能的局限性并展望未来发展方向。
1. 注意力机制背景:从全量自注意力到稀疏注意力
在介绍NSA机制之前,我们先简要回顾Transformer的标准自注意力公式,以及长序列注意力效率问题和以往的一些解决方案。
1.1 标准自注意力机制回顾
Transformer自注意力允许序列中每个元素(称为“查询”query)对其他元素(作为“键”key和“值”value)计算相关性并聚合信息。对于输入序列长度为 n n n,隐藏维度为 d d d,令 Q , K , V ∈ R n × d Q, K, V \in \mathbb{R}^{n \times d} Q,K,V∈Rn×d 分别表示查询、键、值的矩阵表示(通常由输入经过不同线性映射得到)。那么第 i 个位置的注意力输出可以表示为所有值的加权和:
- 注意力权重:首先计算查询与键的点积相似度,并通过softmax获得归一化权重:
α i j = exp ( ( Q K T ) i j / d ) ∑ l = 1 n exp ( ( Q K T ) i l / d ) , \alpha_{i j} \;=\; \frac{\exp\left((Q K^T)_{ij} / \sqrt{d}\right)}{\sum_{l=1}^{n} \exp\left((Q K^T)_{il} / \sqrt{d}\right)} \,, αij=∑l=1nexp((QKT)il/d)exp((QKT)ij/d),
其中 ( Q K T ) i j = q i ⋅ k j (Q K^T)_{ij} = q_i \cdot k_j (QKT)ij=qi⋅kj 表示第 i 个查询和第 j 个键的内积相似度, d \sqrt{d} d 为缩放因子。
- 输出计算:利用注意力权重对值进行加权求和,得到第 i 个位置的输出:
o i = ∑ j = 1 n α i j v j . o_i \;=\; \sum_{j=1}^{n} \alpha_{ij}\, v_j \,. oi=j=1∑nαijvj.
上述计算可以向量化地表示为:
Attention ( Q , K , V ) = softmax ( Q K T d ) V , \text{Attention}(Q, K, V) \;=\; \text{softmax}\!\Big(\frac{Q K^T}{\sqrt{d}}\Big)\, V\,, Attention(Q,K,V)=softmax(dQKT)V,
这是Transformer的基础操作 。自注意力同时考虑所有位置两两之间的交互,因此能够捕获全局依赖,但其计算复杂度和内存消耗随着序列长度平方级增长。当 n n n 很大时(例如上万甚至数万token),这成为主要的性能瓶颈 。
1.2 长序列注意力的挑战与现有方案
为了承载更长的上下文,研究者提出了多种改进自注意力效率的方法:
-
启发式稀疏模式:一些方法在注意力矩阵中采用固定的稀疏模式,使每个token只与一部分tokens交互。例如Longformer使用局部滑动窗口+全局token的模式,限制大部分注意力在局部窗口内,仅少数全局token互看,从而将复杂度降为近线性 。BigBird进一步结合了局部块、随机块和若干全局token的稀疏结构,实现在4096长度下处理长文档问答等任务达到SOTA,同时理论上证明了这种稀疏模式的完备性 。这些方法无需训练注意力模式,但稀疏结构是人为设计,可能无法对不同数据分配的注意力做到自适应最优 。
-
低秩近似:Linformer 提出注意力矩阵可以被低秩矩阵近似,具体做法是对 K K K和 V V V在序列长度维度上乘以一个低秩投影矩阵,将长度 n n n降到 k ≪ n k \ll n k≪n,从而将注意力计算降为 O ( n ⋅ k ) O(n \cdot k) O(n⋅k) ([2006.04768] Linformer: Self-Attention with Linear Complexity)。这使时间和空间复杂度都线性缩放于 n n n。Linformer在一些任务上达到与全注意力相当的效果,但低秩近似对于不同输入的适应性有限,如果真实注意力矩阵并非低秩,则可能损失精度。
-
随机特征近似:Performer等方法用随机特征映射将softmax注意力转化为内积形式,以线性复杂度近似计算注意力,无需显式计算完整 Q K T QK^T QKT 。这种方法在理论上提供无偏近似,但实践中精度取决于随机特征维度,可能需要较大近似维度才能接近全注意力效果。
-
显存优化算法:值得注意的是,有些工作从实现层面优化注意力的内存访问模式,例如FlashAttention 。FlashAttention并未改变注意力结果本身,它通过块状tiling策略,在GPU上分块计算软max和加权和,最大化数据在高速缓存(SRAM)中的重用,减少反复访存HBM(显存)的次数 。这种IO优化使得在相同计算复杂度下,注意力计算的实际速度大幅提升。例如,FlashAttention在GPT-2上达到3倍速度提升,在长序列任务上实现了最高可达16k甚至64k长度的Transformer高效训练 。需要注意FlashAttention依然计算完整的注意力矩阵,只是在实现上更高效;它也可以配合稀疏模式(如块稀疏)一起使用 。
上述方法各有优劣:固定稀疏模式和低秩近似可能牺牲模型对不同场景的适应能力,而完全保真如FlashAttention则仍无法从算法上突破 O ( n 2 ) O(n^2) O(n2)瓶颈。有没有可能设计一种注意力机制,让模型自己学习哪里该稀疏、哪里需精细关注,以在不损失性能的情况下取得与手工优化相媲美的效率?DeepSeek 的 NSA 便是为此而提出。
2. NSA机制的数学原理
NSA(原生稀疏注意力)尝试在不借助预先固定模式的情况下,让注意力机制本身在训练中学会稀疏化策略,从而既保证对重要信息不漏掉,又减少无用的计算。其核心思想是:为每个查询动态地重组一个更紧凑的键值集合,仅包含对该查询最有用的全局和局部信息,然后在该集合上计算注意力。这一过程通过粗粒度压缩和细粒度选择两步实现,并辅以一个滑动窗口分支专门处理局部上下文,最后用一个门控机制融合这些分支的结果。下面我们详细介绍NSA的各组成部分及其公式。
2.1 动态重组注意力:总体框架
对于每个查询向量 q i q_i qi,NSA不再直接与整个序列的键 K K K、值 V V V 计算注意力,而是构造一个远小于 n n n 的子集 K ~ i , V ~ i \tilde K_i, \tilde V_i K~i,V~i 来近似代表原始的 K , V K, V K,V。我们希望 K ~ i , V ~ i \tilde K_i, \tilde V_i K~i,V~i既包含全局上下文的抽象表示,又保留与查询相关的关键细节。形式化地,NSA对第 i i i 个查询的注意力输出定义为:
o i = Attention ( q i , K ~ i , V ~ i ) , o_i \;=\; \text{Attention}\!\Big(q_i,\; \tilde K_i,\; \tilde V_i\Big)\,, oi=Attention(qi,K~i,V~i),
其中 K ~ i , V ~ i \tilde K_i, \tilde V_i K~i,V~i 是从原始上下文 ( K , V ) (K, V) (K,V)为查询 q i q_i q<