DeepSeekNSA省力的同时还能提速!
一、引言:为什么要关注稀疏注意力?
在 Transformer 闪耀 NLP 舞台的时代,模型对长文本的处理仍是痛点:
- 计算瓶颈:全注意力机制需对每对 Token 计算相似度,复杂度为 O(n2)O(n^2)O(n2),长文本时计算量呈平方增长。
- 显存压力:在 GPU 上训练或推理长序列时,显存消耗剧增,甚至溢出。
- 应用场景:文档检索、代码审查、学术论文阅读等场合,需要模型理解成千上万的 Token。
2025 年 ACL’25 最佳论文——《Native Sparse Attention (NSA)》,提出了一种既大幅降低计算成本又提升模型性能的稀疏注意力方案,开启了长上下文处理新纪元。
二、核心原理拆解
2.1 三条稀疏分支如何协作?
-
全局压缩注意力
- 原理:将长度 nnn 的序列通过平均/最大池化等操作压缩到长度 mmm,再在压缩后序列上执行全注意力。
- 计算复杂度由 O(n2)O(n^2)O(n2) 降至 O(m2)O(m^2)O(m2):前者表示随着序列长度平方倍增长,后者仅随压缩后长度平方增长。
- 示例:当 n=65536n=65536n=65536 且 m=1024m=1024m=1024 时,计算量从 4.3×1094.3\times10^94.3×109 降至 1.0×1061.0\times10^61.0×106(减少约四个数量级)。
- 比喻:快速扫读目录,掌握整体框架。
-
选择性注意力(Selective Sparsity)
- 原理:为每个 Token 计算重要性分数(基于键-值内积或门控网络),只保留 Top-kkk 个最重要的 Token 进行注意力计算,复杂度约 O(nk)O(nk)O(nk).
- 动态调整:kkk 可根据任务难度自动伸缩,平衡覆盖率与效率。
- 比喻:用书签标记重点段落,只精读这些片段。
-
滑动窗口注意力(Sliding Window)
- 原理:以步长 sss 和窗口宽度 www 划分序列,在每个窗口内执行全注意力,局部复杂度 O(ws)O(ws)O(ws)。
- 参数示例:常用 w=512,s=256w=512, s=256w=512,s=256,可根据 GPU 显存灵活设置。
- 比喻:用放大镜逐段扫描,确保连续上下文不丢失。
三条分支并行计算后,通过可学习权重 (α,β,γ)(\alpha,\beta,\gamma)(α,β,γ) 加权融合:
NSA(X)=α Aglobal(X)+β Aselective(X)+γ Awindow(X) \mathrm{NSA}(X)=\alpha\,A_{global}(X)+\beta\,A_{selective}(X)+\gamma\,A_{window}(X) NSA(X)=αAglobal(X)+βAselective(X)+γAwindow(X)
2.2 补充知识点
-
对比常见稀疏策略:
- 局部注意力(Local):仅滑窗,缺乏全局感知;
- 块稀疏(Block Sparse):固定块交叉,局部全局互限;
- BigBird:随机+局部+全局,随机性带来不稳定;
- NSA:全局压缩+选择性+滑窗,系统性覆盖全局、重点、局部,兼顾效率与稳定。
-
复杂度公式:NSA 总体复杂度为 O(m2+nk+ws)O(m^2 + nk + ws)O(m2+nk+ws),当 m,k,w,s≪nm,k,w,s\ll nm,k,w,s≪n 时,远低于 O(n2)O(n^2)O(n2)。
-
可视化示例:专栏提供的注意力热力图对比,直观展示 NSA 如何同时捕捉全局脉络与局部细节。
2.3 伪代码示例
# NSA 三分支核心伪码(示例)
import torch.nn.functional as F
def global_compress_attention(X, m):
# X: [B, L, D]
X_pooled = F.adaptive_avg_pool1d(X.transpose(1,2), m).transpose(1,2)
return full_attention(X_pooled)
def selective_attention(X, k):
# 计算重要性分数并取 top-k
scores = (X @ X.transpose(-1,-2)).mean(dim=1) # 简化示例
idx = scores.topk(k, dim=-1).indices
return full_attention(X[idx], X)
def sliding_window_attention(X, w, s):
# 分段滑窗
outs = []
for i in range(0, X.size(1)-w+1, s):
outs.append(full_attention(X[:, i:i+w], X[:, i:i+w]))
return torch.cat(outs, dim=1)
# 融合输出
def NSA_attention(X, config):
g = global_compress_attention(X, config.m)
s = selective_attention(X, config.k)
l = sliding_window_attention(X, config.w, config.s)
return config.alpha*g + config.beta*s + config.gamma*l
此伪代码仅用于流程示意,不可直接用于生产环境。
三、快速记忆核心知识点
- 全局压缩 = 粗读目录:快速建立整体框架;
- 选择性 = 精读标记:专注最重要内容;
- 滑动窗口 = 放大镜:细致扫描局部上下文。
三步结合,让模型既能“见微知著”,又能“略窥全豹”。
四、量化评估
| 评估类型 | 序列长度 | 全注意力延迟 | NSA 延迟 | 加速比 | 备注 |
|---|---|---|---|---|---|
| 推理(Forward) | 64k | 24s | 2.07s | ~11.6× | 无性能损耗 |
| 训练(Backward) | 32k | 18s | 3s | ~6× | 包含自定义反向算子 |
| 显存峰值 | 64k | 100% | ~65% | — | 节省显存 |
| 通用基准(9项平均) | — | — | +0.022 | — | 性能略升 |
| 数学推理(GSM8K) | — | — | +0.034 | — | 推理能力增强 |
| 长文本任务(LongBench) | — | — | +0.032 | — | 精度提升 |
解读:NSA 在长文本场景下实现10×以上加速,同时因聚焦重要信息,性能亦得到提升。
2029

被折叠的 条评论
为什么被折叠?



