DeepSeekNSA省力的同时还能提速!

DeepSeekNSA省力的同时还能提速!

一、引言:为什么要关注稀疏注意力?

在 Transformer 闪耀 NLP 舞台的时代,模型对长文本的处理仍是痛点

  • 计算瓶颈:全注意力机制需对每对 Token 计算相似度,复杂度为 O(n2)O(n^2)O(n2),长文本时计算量呈平方增长。
  • 显存压力:在 GPU 上训练或推理长序列时,显存消耗剧增,甚至溢出。
  • 应用场景:文档检索、代码审查、学术论文阅读等场合,需要模型理解成千上万的 Token。

2025 年 ACL’25 最佳论文——《Native Sparse Attention (NSA)》,提出了一种既大幅降低计算成本提升模型性能的稀疏注意力方案,开启了长上下文处理新纪元。

二、核心原理拆解

2.1 三条稀疏分支如何协作?
  1. 全局压缩注意力

    • 原理:将长度 nnn 的序列通过平均/最大池化等操作压缩到长度 mmm,再在压缩后序列上执行全注意力。
    • 计算复杂度由 O(n2)O(n^2)O(n2) 降至 O(m2)O(m^2)O(m2):前者表示随着序列长度平方倍增长,后者仅随压缩后长度平方增长。
    • 示例:当 n=65536n=65536n=65536m=1024m=1024m=1024 时,计算量从 4.3×1094.3\times10^94.3×109 降至 1.0×1061.0\times10^61.0×106(减少约四个数量级)。
    • 比喻:快速扫读目录,掌握整体框架。
  2. 选择性注意力(Selective Sparsity)

    • 原理:为每个 Token 计算重要性分数(基于键-值内积或门控网络),只保留 Top-kkk 个最重要的 Token 进行注意力计算,复杂度约 O(nk)O(nk)O(nk).
    • 动态调整kkk 可根据任务难度自动伸缩,平衡覆盖率与效率。
    • 比喻:用书签标记重点段落,只精读这些片段。
  3. 滑动窗口注意力(Sliding Window)

    • 原理:以步长 sss 和窗口宽度 www 划分序列,在每个窗口内执行全注意力,局部复杂度 O(ws)O(ws)O(ws)
    • 参数示例:常用 w=512,s=256w=512, s=256w=512,s=256,可根据 GPU 显存灵活设置。
    • 比喻:用放大镜逐段扫描,确保连续上下文不丢失。

三条分支并行计算后,通过可学习权重 (α,β,γ)(\alpha,\beta,\gamma)(α,β,γ) 加权融合:

NSA(X)=α Aglobal(X)+β Aselective(X)+γ Awindow(X) \mathrm{NSA}(X)=\alpha\,A_{global}(X)+\beta\,A_{selective}(X)+\gamma\,A_{window}(X) NSA(X)=αAglobal(X)+βAselective(X)+γAwindow(X)

2.2 补充知识点
  • 对比常见稀疏策略

    • 局部注意力(Local):仅滑窗,缺乏全局感知;
    • 块稀疏(Block Sparse):固定块交叉,局部全局互限;
    • BigBird:随机+局部+全局,随机性带来不稳定;
    • NSA:全局压缩+选择性+滑窗,系统性覆盖全局、重点、局部,兼顾效率与稳定。
  • 复杂度公式:NSA 总体复杂度为 O(m2+nk+ws)O(m^2 + nk + ws)O(m2+nk+ws),当 m,k,w,s≪nm,k,w,s\ll nm,k,w,sn 时,远低于 O(n2)O(n^2)O(n2)

  • 可视化示例:专栏提供的注意力热力图对比,直观展示 NSA 如何同时捕捉全局脉络与局部细节。

2.3 伪代码示例
# NSA 三分支核心伪码(示例)
import torch.nn.functional as F

def global_compress_attention(X, m):
    # X: [B, L, D]
    X_pooled = F.adaptive_avg_pool1d(X.transpose(1,2), m).transpose(1,2)
    return full_attention(X_pooled)

def selective_attention(X, k):
    # 计算重要性分数并取 top-k
    scores = (X @ X.transpose(-1,-2)).mean(dim=1)  # 简化示例
    idx = scores.topk(k, dim=-1).indices
    return full_attention(X[idx], X)

def sliding_window_attention(X, w, s):
    # 分段滑窗
    outs = []
    for i in range(0, X.size(1)-w+1, s):
        outs.append(full_attention(X[:, i:i+w], X[:, i:i+w]))
    return torch.cat(outs, dim=1)

# 融合输出
 def NSA_attention(X, config):
    g = global_compress_attention(X, config.m)
    s = selective_attention(X, config.k)
    l = sliding_window_attention(X, config.w, config.s)
    return config.alpha*g + config.beta*s + config.gamma*l

此伪代码仅用于流程示意,不可直接用于生产环境。


三、快速记忆核心知识点

  • 全局压缩 = 粗读目录:快速建立整体框架;
  • 选择性 = 精读标记:专注最重要内容;
  • 滑动窗口 = 放大镜:细致扫描局部上下文。

三步结合,让模型既能“见微知著”,又能“略窥全豹”。


四、量化评估

评估类型序列长度全注意力延迟NSA 延迟加速比备注
推理(Forward)64k24s2.07s~11.6×无性能损耗
训练(Backward)32k18s3s~6×包含自定义反向算子
显存峰值64k100%~65%节省显存
通用基准(9项平均)+0.022性能略升
数学推理(GSM8K)+0.034推理能力增强
长文本任务(LongBench)+0.032精度提升

解读:NSA 在长文本场景下实现10×以上加速,同时因聚焦重要信息,性能亦得到提升。



评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值