图片速览 DeepSeek NSA Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

PAPERCODE
https://arxiv.org/pdf/2502.11089v1https://github.com/fla-org/native-sparse-attention

摘要

  • 长上下文建模对下一代语言模型至关重要,但标准注意力机制的高计算成本带来了显著的计算挑战。稀疏注意力为提高效率同时保持模型能力提供了一个有前景的方向。本文提出了NSA,一种原生可训练的稀疏注意力机制,它将算法创新与硬件对齐的优化相结合,以实现高效的长上下文建模。NSA采用动态分层稀疏策略,将粗粒度的标记压缩与精细粒度的标记选择相结合,以保留全局上下文感知和局部精度。该方法通过两个关键创新推动了稀疏注意力设计:(1) 通过算术强度平衡的算法设计,在现代硬件的实现优化下,实现了显著的加速。(2) 使得端到端训练成为可能,减少了预训练计算,同时不牺牲模型性能。如图1所示,实验表明,使用NSA预训练的模型在通用基准、长上下文任务和基于指令的推理方面,保持或超过了全注意力模型。同时,NSA在64k长度序列上,在解码、前向传播和反向传播中相较于全注意力实现了显著加速,验证了其在整个模型生命周期中的效率。

在这里插入图片描述

  • NSA架构概述
    • 左图:该框架通过三个并行的注意力分支处理输入序列:对于给定的查询,前面的键和值被Top-n Selection处理为粗粒度模式的压缩注意力,重要标记块的选定注意力,以及局部上下文的滑动注意力。
    • 右图:每个分支产生的不同注意力模式的可视化。绿色区域表示需要计算注意力得分的区域,而白色区域表示可以跳过的区域。

NSA Native Sparse Attention

  • 为了利用自然稀疏模式的注意力潜力,替换原始键值对k和v,用具有更紧凑和信息密集的集合表示键值对。

在这里插入图片描述

  • 然后定义优化的注意力输出如下:

在这里插入图片描述

  • 其中 K t ~ \tilde{K_t} Kt~ V t ~ \tilde{V_t} Vt~ 是根据当前查询 q𝑡 和上下文记忆 k ~ : t \tilde{k}_{:t} k~:t, v ~ : t \tilde{v}_{:t} v~:t动态构建的。
  • 我们可以设计各种映射策略,以获得不同类别的 K t C ~ \tilde{K^C_t} KtC~ V t C ~ \tilde{V^C_t} VtC~,并将它们组合如下:

在这里插入图片描述

  • NSA 有三种映射策略 C = { c m p , s l c , w i n } C = \{cmp, slc, win\} C={cmp,slc,win},分别表示键值的压缩、选择和滑动窗口。 g t c ∈ [ 0 , 1 ] g_{t}^c ∈ [0,1] gtc[0,1] 是对应策略 𝑐 的门控分数,通过多层感知机(MLP)和 sigmoid 激活函数从输入特征中推导出来。
  • N t N_t Nt 表示重新映射后的键值对的总数,我们通过确保 𝑁𝑡 ≪ 𝑡 来维持较高的稀疏性比率:
    在这里插入图片描述

Token Compression

  • 通过MLP(公式中的 ψ ψ ψ)进行分块压缩得到结果 K ~ t c m p ∈ R d k × ⌊ t − l d ⌋ \tilde K_t^{cmp}\in R^{d_k\times \lfloor \frac{t-l}{d} \rfloor} K~tcmpRdk×dtl

在这里插入图片描述

Token Selection

  • 仅使用压缩后的键值对可能会丢失重要的细粒度信息,这促使我们有选择地保留个别键值对。下面我们将描述一种高效的令牌选择机制,该机制能够识别并保留最相关的令牌,同时保持低计算开销。

  • 块级选择 选择策略在空间连续的块中处理键和值序列,主要受到两个关键因素的驱动:硬件效率考虑和注意力得分的固有分布模式。块级选择对于在现代 GPU 上实现高效计算至关重要。这是因为与基于随机索引的读取相比,现代 GPU 架构对连续块访问具有显著更高的吞吐量。此外,块级计算能够实现张量核心的最佳利用。这个架构特性使得块级内存访问和计算成为高性能注意力实现的基本原则,例如 FlashAttention 的基于块的设计。块级选择遵循注意力得分的固有分布模式。先前的研究(Jiang 等人,2024)表明,注意力得分通常表现出空间连续性,暗示相邻的键通常具有相似的重要性水平。文章在第6.2节中的可视化也展示了这种空间连续的模式。

  • 为了实现块级选择,我们首先将键和值序列划分为选择块。为了识别对注意力计算最重要的块,我们需要为每个块分配重要性得分。下面我们将介绍我们计算这些块级重要性得分的方法。

  • Importance Score Computation. 计算块的重要性得分可能会引入显著的开销。幸运的是,压缩令牌的注意力计算会生成中间注意力得分,我们可以利用这些得分来推导选择块的重要性得分,公式如下:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  • Top-𝑛 Block Selection 在获得选择块重要性得分后,保留按块重要性得分排名的前稀疏块内的令牌,公式如下:
    在这里插入图片描述

Sliding Window

  • 在注意力机制中,局部模式通常适应得更快,并可能主导学习过程,从而阻止模型有效地从压缩和选择令牌中学习。为了解决这个问题,我们引入了一个专门的滑动窗口分支,显式地处理局部上下文,使得其他分支(压缩和选择)能够专注于学习各自的特征,而不被局部模式所干扰。具体而言,我们在一个窗口 𝑤 中保持最近的令牌 K ~ t w i n = k t − w : t \color{red}\tilde K_t^{win}=k_{t-w:t} K~twin=ktw:t V ~ t w i n = v t − w : t \color{red}\tilde V_t^{win}=v_{t-w:t} V~twin=vtw:t并将不同信息源(压缩令牌、选择令牌、滑动窗口)的注意力计算隔离到不同的分支中。然后,这些分支的输出通过一个学习的门控机制进行聚合。为了进一步防止跨注意力分支的快捷学习,并且保持最小的计算开销,我们为三个分支提供独立的键和值。这个架构设计通过防止局部模式和长程模式识别之间的梯度干扰,从而实现了稳定的学习,同时引入了最小的开销。

最终结果

  • 在获得所有三类键和值( K ~ t c m p \tilde{K}_t^{cmp} K~tcmp, V ~ t c m p \tilde{V}_t^{cmp} V~tcmp; K ~ t s l c \tilde{K}_t^{slc} K~tslc, V ~ t s l c \tilde{V}_t^{slc} V~tslc K ~ t w i n \tilde{K}_t^{win} K~twin, V ~ t w i n \tilde{V}_t^{win} V~twin;)之后,我们按照公式 (5) 计算最终的注意力输出。结合上文描述的压缩、选择和滑动窗口机制,这构成了 NSA 的完整算法框架。
    在这里插入图片描述
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值