长序列建模一直是自然语言处理领域的难题,而 NSA 的出现,为解决这一难题提供了新的思路。
核心贡献
首次实现硬件对齐的稀疏注意力机制,显著提升长上下文建模效率。
技术突破点
- 创新方法:用动态分层稀疏策略替代传统全注意力机制
- 关键优势:在64k长度序列上,解码、前向传播和反向传播速度提升显著
- 基础原理:类似"给AI添加决策流程图",通过压缩和选择关键信息减少计算量
现实影响
- 短期应用:加速长文本处理,如代码生成和文档分析
- 长期潜力:推动通用AI发展,实现更高效的长上下文推理
- 风险警示:可能在某些任务上引入信息丢失,需进一步优化
延伸思考
- 该研究与马斯克的脑机接口是否有结合可能?
- 如果技术开源,开发者最应该关注哪个模块?
资源索引
原文精简翻译
1. 引言
随着下一代大语言模型对长上下文建模的需求日益增加,标准注意力机制的高计算成本成为了一个显著的挑战。稀疏注意力机制为提高效率提供了有前景的方向,同时保持模型的能力。我们提出了NSA(Natively trainable Sparse Attention),一种硬件对齐且可原生训练的稀疏注意力机制,通过算法创新和硬件优化实现高效的长上下文建模。
图1展示了NSA在64k长度序列处理中的显著计算加速,同时在通用基准、长上下文任务和推理评估中超越全注意力基线。
2. 重新思考稀疏注意力方法
现代稀疏注意力方法在减少Transformer模型的理论计算复杂度方面取得了显著进展。然而,大多数方法主要在推理阶段应用稀疏性,而保留了预训练的全注意力骨干,这可能引入架构偏差,限制其充分利用稀疏注意力优势的能力。
2.1 高效推理的错觉
尽管许多方法在注意力计算中实现了稀疏性,但由于以下两个挑战,许多方法未能实现相应的推理延迟减少:
- 阶段限制的稀疏性:例如H2O在自回归解码期间应用稀疏性,但在预填充阶段需要计算密集的预处理。
- 与高级注意力架构的不兼容性:一些稀疏注意力方法无法适应现代解码高效架构,如多查询注意力(MQA)和分组查询注意力(GQA)。
2.2 可训练稀疏性的神话
我们分析推理方法的两个关键见解促使我们追求原生可训练的稀疏注意力:
- 性能下降:事后应用稀疏性迫使模型偏离其预训练的优化轨迹。
- 训练效率需求:高效处理长序列训练对于现代LLM开发至关重要。
3. 方法论
我们的技术方法涵盖算法设计和内核优化。我们首先介绍方法背景,然后介绍NSA的整体框架,接着是其关键算法组件,最后是我们硬件优化的内核设计。
3.1 背景
注意力机制广泛用于语言建模,其中每个查询标记计算与所有先前键的相关性分数,以生成值的加权和。
3.2 整体框架
为了利用自然稀疏模式的潜力,我们提出用更紧凑和信息密集的表示键值对替换原始键值对。
图2展示了NSA的架构框架,通过三个并行注意力分支处理输入序列:压缩注意力、选择注意力和滑动窗口注意力。
3.3 算法设计
我们介绍了重新映射策略和的设计:标记压缩、标记选择和滑动窗口。
3.3.1 标记压缩
通过将键或值的顺序块聚合为块级表示,我们获得压缩的键和值,捕捉整个块的信息。
3.3.2 标记选择
仅使用压缩键和值可能会丢失重要的细粒度信息,促使我们选择性地保留单个键和值。
3.3.3 滑动窗口
在注意力机制中,局部模式通常适应得更快,并可能主导学习过程,防止模型从压缩和选择标记中有效学习。为了解决这个问题,我们引入了一个专门的滑动窗口分支,明确处理局部上下文。
3.4 内核设计
为了在训练和预填充期间实现FlashAttention级别的加速,我们在Triton上实现了硬件对齐的稀疏注意力内核。
图3展示了NSA的内核设计,通过GQA组加载查询(网格循环),获取相应的稀疏KV块(内部循环),并在SRAM上执行注意力计算。
4. 实验
我们通过三个视角评估NSA:通用基准性能、长上下文基准性能和链式思维推理性能,与全注意力基线和最先进的稀疏注意力方法进行比较。
4.1 预训练设置
我们的实验采用结合分组查询注意力(GQA)和专家混合(MoE)的骨干,总参数为27B,活动参数为3B。
图4展示了NSA和全注意力基线的预训练损失对比,NSA始终表现更好。
4.2 基线方法
除了与全注意力进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法:H2O、infLLM、Quest和Exact-Top。
4.3 性能比较
模型 | SQA | MQA | Synthetic | Code | Avg. |
---|---|---|---|---|---|
H2O | 0.428 | 0.429 | 0.308 | 0.112 | 0.101 |
InfLLM | 0.474 | 0.517 | 0.356 | 0.306 | 0.250 |
Quest | 0.495 | 0.561 | 0.365 | 0.295 | 0.245 |
Exact-Top | 0.502 | 0.605 | 0.397 | 0.321 | 0.288 |
Full Attn | 0.512 | 0.623 | 0.409 | 0.350 | 0.305 |
NSA | 0.503 | 0.624 | 0.432 | 0.437 | 0.356 |
表2展示了NSA在LongBench上的性能比较,包括单文档QA、多文档QA、合成和代码任务类别。NSA在大多数基准上表现优于全注意力。
5. 效率分析
我们在8-GPU A100系统上评估NSA与全注意力的计算效率。
5.1 训练速度
我们比较了基于Triton的NSA注意力和全注意力与Triton-based FlashAttention-2的实现。
图6展示了NSA内核在64k上下文长度上实现了9.0倍的前向和6.0倍的反向加速。
5.2 解码速度
注意力的解码速度主要由内存访问瓶颈决定,这与KV缓存加载量密切相关。
上下文长度 | 8192 | 16384 | 32768 | 65536 |
---|---|---|---|---|
全注意力 | 8192 | 16384 | 32768 | 65536 |
NSA | 2048 | 2560 | 3584 | 5632 |
预期加速 | 4× | 6.4× | 9.1× | 11.6× |
表4展示了解码期间每次注意力操作的内存访问量(以等效标记数表示)。由于解码的低算术强度和内存限制,预期加速与内存访问量近似线性。
6. 讨论
在本节中,我们反思NSA的开发过程,并讨论从不同稀疏注意力策略探索中获得的关键见解。
6.1 替代标记选择策略的挑战
在设计NSA之前,我们探索了将现有稀疏注意力方法适应训练阶段。然而,这些尝试遇到了各种挑战,促使我们设计不同的稀疏注意力架构。
图7展示了在3B参数模型上,NSA与全注意力和不同标记选择策略的训练损失对比。NSA表现更好。
6.2 可视化
为了探索Transformer注意力分布的潜在模式并为我们的设计寻求灵感,我们可视化了预训练的27B全注意力模型的注意力图。
图8展示了注意力图的可视化,浅色区域表示较高的注意力值。如图所示,注意力分数呈现出块状聚类分布。
7. 相关工作
我们回顾了通过稀疏注意力提高注意力计算效率的现有方法。这些方法可以根据其核心策略大致分为三类:固定稀疏模式、动态标记修剪和查询感知选择。
7.1 固定稀疏模式
滑动窗口是一种常用方法,允许查询仅在固定窗口内计算注意力。
7.2 动态标记修剪
H2O实现了在解码期间减少KV缓存内存使用的自适应方法。
7.3 查询感知选择
Quest采用块状选择策略,其中每个块的重要性通过查询与键块的坐标最小最大值乘积来估计。
8. 结论
我们提出了NSA,一种硬件对齐的稀疏注意力架构,用于高效的长上下文建模。通过将分层标记压缩与块状标记选择结合在可训练架构中,我们的架构在保持全注意力性能的同时实现了加速训练和推理。
附录A AIME结果示例
提示:
我们的结果:
基线结果:
专业术语词典
1. 稀疏注意力(Sparse Attention)
- 解释:一种减少注意力机制计算量的方法,通过只计算部分关键的查询-键对,而不是所有可能的对。类似于“选择性阅读”,只关注最重要的部分。
2. 全注意力(Full Attention)
- 解释:传统的注意力机制,计算每个查询与所有键的相关性。类似于“逐字阅读”,计算量大但信息全面。
3. 硬件对齐(Hardware-Aligned)
- 解释:算法设计时考虑硬件的特性(如内存访问模式、计算单元利用率),以最大化硬件性能。类似于“为特定机器定制工具”。
4. 解码(Decoding)
- 解释:在生成文本时,模型逐步预测下一个词的过程。类似于“逐词写作”。
5. 前向传播(Forward Propagation)
- 解释:神经网络中,输入数据通过各层计算得到输出的过程。类似于“信息从输入到输出的流动”。
6. 反向传播(Backward Propagation)
- 解释:神经网络训练时,根据损失函数计算梯度并更新模型参数的过程。类似于“根据错误调整学习方法”。
7. KV缓存(KV Cache)
- 解释:在解码过程中,存储键(Key)和值(Value)的内存区域,用于加速注意力计算。类似于“临时存储重要信息的笔记本”。
8. 算术强度(Arithmetic Intensity)
- 解释:计算操作数与内存访问数的比值,用于衡量算法在硬件上的效率。类似于“计算量与数据搬运量的平衡”。
9. 分组查询注意力(Grouped-Query Attention, GQA)
- 解释:一种注意力机制变体,多个查询共享相同的键和值,以减少内存访问和计算量。类似于“多人共享同一份参考资料”。
10. 专家混合模型(Mixture of Experts, MoE)
- 解释:一种模型架构,包含多个专家网络,每个输入由部分专家处理。类似于“分诊系统,不同问题由不同专家解决”。
11. 滑动窗口(Sliding Window)
- 解释:一种局部注意力机制,只计算查询附近一定范围内的键值对。类似于“只关注当前段落的内容”。
12. 注意力分数(Attention Score)
- 解释:衡量查询与键之间相关性的数值,用于加权值。类似于“关键词的重要性评分”。
13. 链式思维推理(Chain-of-Thought Reasoning)
- 解释:一种推理方法,模型逐步生成中间步骤以解决问题。类似于“逐步推导数学题的解法”。
14. 预训练(Pretraining)
- 解释:在大规模数据上训练模型,使其学习通用特征。类似于“学习基础知识”。
15. 微调(Fine-Tuning)
- 解释:在特定任务数据上进一步训练预训练模型,以适应具体任务。类似于“针对特定领域进行专项训练”。
16. FlashAttention
- 解释:一种高效的注意力计算实现,通过优化内存访问和计算顺序来加速注意力机制。类似于“优化后的快速阅读方法”。
17. Triton
- 解释:一种用于编写高效GPU内核的编程语言和编译器,常用于深度学习模型的加速。类似于“为GPU定制的编程工具”。
18. 内存访问瓶颈(Memory Access Bottleneck)
- 解释:计算任务中,内存访问速度限制了整体性能的现象。类似于“搬运数据的速度跟不上计算速度”。
19. 块状选择(Blockwise Selection)
- 解释:将键和值序列分成连续块,并选择最重要的块进行计算。类似于“按章节选择重要内容”。
20. 注意力图(Attention Map)
- 解释:可视化注意力分数的矩阵,显示查询与键之间的相关性。类似于“关键词与上下文的关联图”。