【Attention优化重大突破!显存减半效率倍增,大模型长文本处理迎来新时代】

一,前言

从Transformer 17年被提出来, Attention 层的相关的工作从未间断,参考邱锡鹏教授21年关于Transformer的综述里的整理如下:

更耳熟能详一些的如:MQA,GQA ,DeepSeek的MLA,阶跃星辰的MFA,以及最近DeepSeek提出的NSA和Kimi提出的MoBA。

到底这么多的Attention为什么被提出,它们都意图解决什么问题。 我们用接下来的两个章节来解释一下。

二,Attention 是什么

为了更容易理解我们接下来的内容,我们先来看一下Attention的计算公式:

这个公式到底在做什么,flashattention的论文里说的比较清楚:

结合下图更有利于我们深入的理解attention(11. Self Attention - by Tom Yeh - AI by Hand ✍️

下面会对以上的这张图做分块的解释:

语言模型简单来说就是针对输入的序列计算概率,推测下一个token是什么,新生成的token以及历史上所有的token又做为下一轮生成的输入。 上面这个图是模型多层transformer中的一层的计算过程的具象化展示,主要含有Attention和FFN(MLP)这2个子层。下面几个小节主要解释attention的计算部分。

1, 我们假设该图是transformer第一层,每个token是一个词,这样可以简化讨论。假设有4个词做为输入,被转化为4个向量,也就是图里的Features:x1,x2,x3,x4。 通过和模型的权重矩阵Wq,Wk 相乘,得到一个Q矩阵,也就是图里的q1,q2,q3,q4 这4个销向量。得到一个K矩阵也就是k1,k2,k3,k4 这4个销向量;

2, Q,K 通过矩阵乘法计算得到一个矩阵S

3,然后通过softmax 等计算得到矩阵P

4,输入x1,x2,x3,x4通过和模型的权重矩阵Wv相乘得到v1,v2,v3,v4 也就是矩阵V。 最后矩阵P和V相乘得到最终的输出向量z1,z2,z3,z4,也就是矩阵O

三,为什么能做和要做前言里提到的优化工作

从第二章节不难看出(Decoder-only的attention计算和上面略有不同,不影响计算复杂度讨论),如果不做任何优化,生成每一个token的计算复杂度是O(n^2),最终生成的序列全局计算复杂度是O(n^3)。对于上下文这个计算复杂度肯定是无法接受的。

1,所以,直觉上提升Attention的性能的做法是降低它的计算复杂度.

kv caching就是为了解决这个问题将单个token的计算复杂度降低到O(n) (n为当前序列长度),全局的复杂度就下降到O(n^2)极大的提升了性能。比如下图开不开起kv caching 作者观察到的。 kv caching 是一个空间换时间的做法,那么我们需要付出的就是更多的显存空间用来做k,v 的缓存。

kv caching 为什么可行参考下面2个图(https://x.com/_avichawla/status/1890288542322221206):

大体来说就是Decoder 里的attention计算由于当前token只参考之前的token而不向后参考,所以如第二章节的历史上的z向量可以在新token生成时不更新,Zk = (Qk * K) * V。

2,kv caching 极大的提升了性能,但是需要大量的cache 空间。比如上图所示看Llama3 70B,4k token需要10.5 GB 的cache空间,这个是极大的开销。为了降低cache的压力,开头提到的MQA,GQA等被提出,通过降低k v 的数量来降低对显存的消耗。同时模型能力没有出现明显的下降,所以得到了广泛的使用。DeepSeek 的MLA 则是通过压缩k v 大小的方式,进一步降低了cache 的压力,这里的风险是信息的损失会导致模型性能的下降,但从实际的结果来看,这个风险目前被处理得很好。

3,即使做了1,2的优化,问题就不存在了吗?当然不是,对于更长的上下文cache的压力还是很大,计算压力还是O(n^2) (n 为序列长度),这样模型实际是很难继续朝着长上下文scaling的。2和其他的一些研究表attention是稀疏的,序列越长attention越稀疏,像DeepSeek 的NSV,Kimi的MoBA以及别的一些研究就诞生了,通过这些研究我们可以进一步降低计算量和cache大小同时不降低模型性能。

参考资料:

https://arxiv.org/pdf/1706.03762

https://arxiv.org/pdf/2106.04554

https://arxiv.org/pdf/2205.14135

https://arxiv.org/pdf/2502.11089

11. Self Attention - by Tom Yeh - AI by Hand ✍️

The Annotated Transformer

https://x.com/_avichawla/status/1890288542322221206

### 大模型长文本处理中的技术实现方法 大模型长文本处理方面面临的主要问题是上下文窗口限制[^1]。这种限制使得传统的大规模预训练语言模型难以有效处理超长序列数据。然而,随着研究的深入和技术的进步,已经出现了多种针对这一问题的技术解决方案。 #### 一、注意力机制优化 为了应对长文本带来的计算复杂度增加的问题,研究人员提出了稀疏注意力机制(Sparse Attention Mechanism)。该机制通过减少全连接注意力矩阵的计算量来提高效率。例如,Transformer-XL 和 Longformer 是两种典型的改进方案。其中 Transformer-XL 利用了相对位置编码和分段循环策略扩展了模型的记忆能力[^3];而 Longformer 则引入了一种局部加全局的混合注意力模式,允许模型同时关注短距离依赖关系以及重要远距离信息[^2]。 #### 二、滑动窗口与分块策略 另一种常见的做法是对输入序列进行分割成较小的部分并分别建模后再组合起来形成最终表示向量。这种方法可以显著降低内存消耗同时也保持较好的性能表现。具体来说,可以通过固定大小的滑动窗或者不重叠区域划分等方式来进行操作,并且还需要设计合理的边界融合逻辑以便更好地捕捉跨区块之间的关联特性。 #### 三、基于记忆网络的方法 除了上述提到的一些架构层面调整外, 还有利用外部存储器增强内部状态表达力的思想被广泛应用于解决此类难题上 。比如 MemN2N (Memory Network with Neural Networks), 它能够动态更其工作集从而适应不同长度的任务需求; 另外还有 Recurrent Entity Network 等变体形式也展示出了不错的效果 . ```python class SparseAttention(nn.Module): def __init__(self, dim_model, num_heads=8, window_size=512): super(SparseAttention, self).__init__() assert dim_model % num_heads == 0 self.num_heads = num_heads self.window_size = window_size def forward(self, Q, K, V): # Implement sparse attention logic here... pass ``` 以上介绍了几种主流用于提升大型神经网络框架对于海量连续型资料加工效能的关键措施之一部分例子而已 , 实际当中还存在着更多创性的思路等待我们去发掘尝试 .
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值