SnapKV
摘要
大型语言模型(LLMs)在处理长上下文方面取得了显著进展,其中键值(KV)缓存在提升模型性能中起到了关键作用。然而,随着输入长度的增加,KV缓存的增长对内存和时间效率提出了挑战。为了解决这个问题,本文提出了SnapKV,一种创新的、无需微调的KV缓存压缩方法,在保持实际应用中相似性能的情况下,有效地缩小了KV缓存大小。
我们发现,模型中的每个注意力头在生成过程中始终专注于特定的提示注意力特征。同时,这种稳定的模式可以从提示末端的“观察”窗口中获得。基于这一洞察,SnapKV通过自动选择每个注意力头的重要KV位置来压缩KV缓存。我们的方法显著减少了长输入序列处理时不断增长的计算开销和内存占用。具体而言,SnapKV在处理16K tokens的输入时,解码速度提高了3.6倍,内存效率提升了8.2倍,同时在16个长序列数据集上的性能与基线模型相当。此外,SnapKV可以在单块A100-80GB GPU上处理最多380K的上下文tokens,使用 HuggingFace 实现并进行少量修改,仅在“Needle-in-a-Haystack”测试中表现出微小的精度下降。进一步的综合研究表明,SnapKV在实际应用中具有潜力。
1 引言
挑战与动机
随着大型语言模型(LLM)技术的发展,许多领先的模型开始能够处理更长的上下文,这克服了上下文维护和注意力机制可扩展性方面的技术难点。例如,GPT-4 和 Command-R 能够处理长度达128K的上下文,Claude-3 能达到200K,而 Gemini-Pro-1.5 则支持高达1M的上下文长度。尽管这些模型展示了强大的处理能力,但它们在应对长提示输入时依然存在显著的挑战,尤其是在KV缓存效率方面。在推理阶段,随着提示长度的增加,每步解码的延迟会线性增长。这是因为模型需要对过去的KV执行注意力计算,以生成当前的响应。随着上下文长度增加,所需的KV缓存也显著增长,从而导致更高的内存需求。为了满足这种内存要求,硬件的负担随之增加,限制了模型的可扩展性。
缓解这些问题的方法有很多,例如在生成阶段KV缓存被移除[ 5-8 ]。然而,这些方法大多在长上下文环境下缺乏详细的评估。此外,他们主要关注于压缩解码步骤中附加的KV缓存,而忽略了为提示压缩KV缓存的现实问题,这通常是内存效率的瓶颈。在实际应用中,如聊天机器人和代理,其中提示范围从多轮对话到广泛的文章或代码库[ 1、9、10],提示往往比生成的响应(如摘要和代码片段)大得多,从而产生显著的推理延迟和内存使用开销。额外的挑战在于压缩KV缓存以获得如此庞大的提示,同时又不丢失精确生成的关键信息,特别是在具有各种噪声上下文的场景中。
在实际应用中,这种解码延迟和内存占用的提升对用户体验产生了负面影响。例如,聊天机器人和智能代理通常需要处理多轮对话或长篇文章甚至代码片段,这些提示往往比生成的响应内容更长。由于长提示会显著增加解码延迟和内存开销,模型需要一种有效的方法来应对这些挑战,以确保模型在处理长上下文时既高效又准确。
SnapKV的核心思想
为了应对上述挑战,本文提出了一种称为 SnapKV 的创新技术。这项技术通过在生成过程中有效地压缩KV缓存,显著减少了内存和计算资源的消耗,从而解决了长上下文处理中的瓶颈问题。我们的核心发现是,模型的每个注意力头在生成过程中会始终聚焦于特定的提示注意力特征,这一模式可以通过在提示末端设置“观察窗口”来提取。
基于这一发现,SnapKV方法设计了一个无需微调的算法来识别这些重要的注意力特征,并对KV缓存进行选择性压缩,以减少内存占用和提高解码效率。具体而言,SnapKV能够自动筛选出每个注意力头在提示中所关注的关键KV位置,从而显著压缩KV缓存的体积。这样在处理长输入序列时,不仅减少了计算开销,同时也减小了内存使用。
SnapKV的设计可以在保持生成质量的前提下,将KV缓存的大小控制在一个恒定的范围内,这种创新使得模型可以在处理长提示时始终保持稳定的解码速度和内存效率。与传统的KV缓存管理方法相比,SnapKV不依赖于额外的模型微调或复杂的算法调整,具有很强的适用性和实用性。
通过这些观察,我们发现了一个重要的注意力分配现象:只有一部分提示token传达了对响应生成至关重要的信息,而这些token在生成过程中保持不变。为了验证这一发现的稳健性,我们设计了一系列实验,涵盖了不同长度、格式和内容的多样化提示。基于我们的观察,我们衍生出了一个创新而直观的方法,SnapKV,它可以智能地识别注意力分配模式,并在不牺牲模型准确性的情况下压缩长序列提示的KV缓存。凭借其全面的设计,SnapKV在各种数据集上展示了其有效性,并且只需进行少量代码调整,就可以轻松地集成到流行的深度学习框架中。我们的贡献如下:
-
我们设计了实验来探索生成过程中的注意力分配模式,重点关注两个关键问题:
- 输入序列token是否存在一致的注意力分配模式?
- 在生成阶段之前识别这种模式是否可行?
-
我们的发现表明,对于LLMs,大多数输入序列token的注意力分配在生成过程中保持一致。因此,LLMs在生成之前就知道你在寻找什么。
-
受到上述观察的启发,我们开发了一个高效且无需微调的算法,SnapKV,它可以有效识别关键注意力特征,并相应地压缩KV缓存,并且对模型的修改最小(见图1)。
-
我们在不同的LLMs和长序列数据集上评估了SnapKV。SnapKV在保持与完整KV缓存方法相当的准确性的同时,展示了改进的解码速度和内存效率。同时,我们通过Needle-in-a-Haystack进行压力测试,进一步展示了其内存效率和信息检索能力。
2 相关工作
许多先前的工作通过选择性地丢弃KVs来压缩KV缓存,使用不同的算法。在StreamLLM中,只保留最近的token和注意力汇(前几个token),以减少KV缓存的大小,这使得它丢失了被丢弃的中间token所携带的重要信息。Heavy-Hitter Oracle(H2O)引入了一种策略,该策略基于累积注意力从评分函数中派生出贪婪地丢弃KVs。虽然这种方法有效地压缩了生成过程中添加到缓存中的KVs,但它忽略了对提示KVs的压缩,这对于减少内存和计算开销至关重要。基于类似的概念,自适应KV压缩(FastGen)实现了一个双阶段算法,包括四种KV缓存压缩策略。最初,它通过从提示编码获得的分析结果确定最佳策略。随后,它根据这些策略在生成阶段动态驱逐缓存。然而,它面临着与H2O类似的问题。ScissorHands专注于识别和保留在生成步骤中与前一个token窗口的一致注意力权重模式的至关重要的token。然而,这种方法只关注生成过程中前一个关键token窗口,而忽略了包含生成准确响应所必需信息的广泛提示。这种疏忽可能导致无法从提示中提取详细信息。
总结来说,现有方法没有有效地解决现实世界应用中遇到的挑战,其中提示异常长,但需要准确的信息检索。尽管这些技术可能在生成过程中减少了KV缓存的大小,但它们没有解决理解复杂提示上下文的主要挑战,留下了关键问题未解决。
3 观察
在本节中,我们介绍了关于token生成期间QueryKey矩阵中注意力分配模式的观察。我们的分析使用了Ultrachat样本,这是一个多轮次、高质量的指令数据集,包含140万对话。我们进一步过滤了响应长度大于512且提示长度大于3k的序列。我们的发现可以总结为两个关键观察点:
3.1 生成前的模式可识别
在这个实验中,我们将每个层的输入序列的注意力特征分成多个窗口,每个窗口包含128个token,并分别计算最后20个窗口的平均注意力权重。为了理解输入序列沿注意力分配模式,我们计算了输入序列中重要注意力特征(那些具有高平均注意力权重的)的重叠率。实验结果如图2所示。
我们观察到,输入序列的最后一个窗口与实际生成的注意力分配模式非常相似。
3.2 生成过程中模式一致
我们研究了输入序列中最后一个窗口识别为关键的特征在随后的token生成中是否保持其重要性。在实验中,我们将生成的token分成4个窗口,每个窗口跨越128个token,以计算这些窗口与输入序列的最后一个窗口的平均重叠率。如图3所示,输入序列的活跃注意力特征从最后一个窗口中获得,在生成过程中表现出显著的一致性,这由高重叠率证明。
4 SnapKV
在注意力机制中,提示的增长将显著增加生成的时间复杂度,因为Query-Key矩阵乘法。SnapKV通过在生成期间保持恒定数量的提示KVs来解决这个问题,显著减少了长上下文LLMs的服务时间。为了使我们的方法结构清晰,我们提出了以下术语:
-
提示长度( L p r o m p t L_{prompt} Lprompt):用户提供的输入的总长度。
-
观察窗口( L o b s L_{obs} Lobs):提示的最后部分。这个窗口对于分析不同上下文对注意力分配模式的影响至关重要。
-
前缀长度( L p r e f i x L_{prefix} Lprefix):输入的前缀长度,不包括观察窗口。总的来说,我们有: L p r o m p t = L p r e f i x + L o b s L_{prompt} = L_{prefix} + L_{obs} Lprompt=Lprefix+Lobs
-
投票:计算观察窗口内每个查询的注意力权重的过程,聚合这些权重以突出被认为是最重要的前缀位置。对于单个批次的序列,正式地: C = ∑ i = 0 L o b s W o b s [ : , i , : ] C = \sum_{i=0}^{L_{obs}} W_{obs}[:, i, :] C=i=0∑LobsWobs[:,i,:] I = T o p k ( C , k ) I = Top_k(C, k) I=Topk(C,k)
-
命中率:我们定义在生成过程中超过预定义阈值θ的注意力特征为重要特征。命中率H是通过前一投票过程成功选择的重要特征的数量除以总重要特征的数量。H量化了投票机制的有效性,并按以下方式计算: M v o t e _ o b s = z e r o s _ l i k e ( A c u r ) M_{vote\_obs} = zeros\_like(A_{cur}) Mvote_obs=zeros_like(Acur) M v o t e _ o b s [ I ] = 1 M_{vote\_obs}[I] = 1 Mvote_obs[I]=1 M t h r e s h o l d _ c u r = 1 ( A c u r > θ ) M_{threshold\_cur} = 1(A_{cur} > θ) Mthreshold_cur=1(Acur>θ) O = M t h r e s h o l d _ c u r ∧ M v o t e _ o b s O = M_{threshold\_cur} ∧ M_{vote\_obs} O=Mthreshold_cur∧Mvote_obs H = ∑ O ∑ M t h r e s h o l d _ c u r H = \frac{\sum O}{\sum M_{threshold\_cur}} H=∑Mthreshold_cur∑O
4.1 基于观察窗口的算法
SnapKV的核心方法涉及识别和选择每个头部最关键的注意力特征来创建压缩的KV缓存。列表1显示了SnapKV的PyTorch风格伪代码。总体而言,SnapKV通过两个阶段操作:
-
为重要的先前特征投票。通过上述定义的投票过程(方程2),我们基于观察窗口选择重要的注意力特征。第3节强调了观察窗口内注意力分配模式的一致性,表明这些选择的注意力特征对后续生成也至关重要。此外,我们实现了聚类以保留选定注意力特征周围的特征(第4.3节)。代码的第8-17行显示了投票过程的伪代码。
-
更新和存储压缩的键和值。我们将选定的注意力特征与观察窗口内的所有特征连接起来,这包含了包含所有必要提示信息的所有特征。代码的第18-24行显示了压缩过程。连接的KVs存储起来,供后续生成使用,从而节省内存使用。
4.2 命中率的鲁棒性分析
为了理解基于观察窗口的算法的鲁棒性,我们在多个长文档QA数据集上分析了其命中率,包括QMSum、Openreview和SPACE。我们探测的模型是Mistral-7B-Instruct-v0.2。总的来说,我们想要回答以下两个问题:
- 提示中的指令性质是否影响命中率?
- 上下文和指令定位是否影响命中率?
4.2.1 模式的上下文依赖性
我们分析了即使提供相同的上下文,指令是否会根据选择的重要特征而变化。我们的实验使用了同一文档的不同指令,并基于包含指令及其相应响应的观察窗口选择了重要特征。然后我们计算了在同一文档内不同指令-响应对选择的重要特征之间的命中率,使用H(Mvote_A, Mvote_B)。通过改变指令,我们观察到不同的指令优先考虑不同的前缀注意力特征,这表明命中率呈下降趋势。我们的发现揭示了LLMs中KV缓存的一个有趣方面:重要的注意力特征随着不同的指令而变化。这种可变性挑战了依赖于恒定加权重要性或固定策略的静态压缩方法的有效性。因此,上下文与相关KV缓存之间的复杂关系强调了需要上下文感知的压缩策略,并突出了SnapKV的能力,即识别这种动态。
4.2.2 对指令位置的不变性
我们的分析还扩展到了指令定位对LLMs的可解释性和重要特征选择的重要性。我们计算了使用与之前实验相同的观察窗口大小的响应的平均命中率。我们的结果如图5所示,表明在所有三个数据集中,无论指令是位于广泛补充上下文之前还是之后,命中率都一致地高。这种一致性表明,SnapKV能够识别注意力分配模式,无论问题的位置如何。
4.3 通过池化高效聚类
在LLMs中,信息检索和生成依赖于高注意力权重的特征,并通过对上下文使用感应头来补充其余特征。因此,简单地选择顶部特征只保留了部分细节,然后失去了信息的完整性。例如,这种压缩可能导致LLMs只检索到电话号码的国家代码,并幻想出其余部分。我们的实验还揭示了仅选择具有最高权重的特征是不足够的。这种稀疏选择冒着损害特征之间包含的上下文完整性的风险,从而降低了准确性。基于这些见解,我们提出了一个细粒度的聚类算法,使用池化层,如代码的第13行所示。
5 实验
在我们的实验设置中,我们探索了SnapKV在能够处理扩展提示序列上下文的模型中的性能。首先,我们对LWM-Text-Chat-1M进行了压力测试,并基准测试了其速度,这是关于其上下文长度的最新技术。然后,我们对Mistral-7B-Instruct-v0.2进行了消融研究,以了解池化对模型信息检索性能的影响。最后,我们展示了SnapKV可以与其他加速策略(如并行解码)一起使用。
5.1 LWM-Text-Chat-1M上的基准测试
LWM-Text-Chat-1M是一个7B指令微调模型,上下文长度可达一百万。在本节中,我们对SnapKV进行了压力测试,并检查了其算法效率。
5.1.1 Needle-in-a-Haystack
Needle-in-a-Haystack测试挑战模型从隐藏在广阔文档中的特定句子(“needle”)中准确检索信息,该句子随机放置。通常,插入提示中间的句子更难检索。为了严格评估SnapKV的能力,我们将文档长度扩展到380ktoken,这是单个A100-80GB GPU可以处理的最长内容。我们将提示KV缓存大小配置为1024,允许SnapKV从提示中选择最关键的1024个注意力特征来生成答案,最大池化核大小为5,观察窗口大小为16,这两个超参数可以根据需要进行定制。图6中Needle-in-a-Haystack测试的引人注目的结果强调了SnapKV在极端长输入上下文中精确管理小细节的潜力,压缩比为380倍。
5.1.2 解码速度和内存限制
我们进一步在不同批量设置下基准测试了LWM-Text-Chat-1M的速度。我们将SnapKV的最大KV缓存大小设置为2048,并固定生成长度为512以确保公平比较。图7显示了不同批量大小下基线实现和SnapKV优化模型的解码延迟比较。实验在A100 80GB GPU上进行。红色虚线表示最新长序列模型的通用上下文长度。
5.2 池化的消融研究
我们在Mistral-7B-Instruct-v0.2上进行了消融研究,以评估我们的池化技术的影响,这是一种通过聚类整合信息的简单但有效的方法。我们使用修改后的LongEval-Lines基准测试进行评估,包括随机生成的对和平均分数。LongEval-Lines比Needle-in-a-Haystack更具挑战性,因为它涉及在相同格式的嘈杂上下文中识别键值对,而在Needle-in-a-Haystack中,相关信息更明显地与其它上下文分离。我们应用最大池化,核大小为5,并使用大小为16的观察窗口,这些超参数可以根据不同的模型进行定制。正如我们的结果(图8)所示,我们发现池化显著提高了检索准确性,与不使用池化的方法相比。我们假设这是因为注意力机制对关键token聚类最初部分的权重更高。通常,大型语言模型倾向于复制最初部分周围的token以保持上下文的完整性。然而,简单地压缩KV缓存破坏了这种机制,可能导致部分正确的结果(图8)。请注意,在我们的实验中,最大池化和平均池化的选择在性能上没有显著差异。
5.3 LongBench上的实验
我们在LongBench上评估了SnapKV在这四个模型上的性能,这是一个多任务基准测试,旨在严格评估各种数据集上的长上下文理解能力。我们选择LWM-Text-Chat-1M(上下文长度一百万)、LongChat-7b-v1.5-32k、Mistral-7B-Instruct-v0.2、Mixtral-8x7B-Instruct-v0.1(上下文长度32k)作为我们的基线。对于每个模型,我们测试了SnapKV的不同设置:将提示中的KV缓存压缩到1024、2048和4096个token。我们使用最大池化,核大小为7,观察窗口大小为32。表1显示了SnapKV与原始实现相比,在16个不同数据集上的模型性能仅有微不足道的下降,即使在提示-KV为1024个token的情况下也是如此。一些模型甚至超过了基线。我们的结果证实了SnapKV能够把握长上下文中的关键信息,并提供详细的综合摘要。此外,我们的结果还表明了SnapKV在压缩提示KV缓存方面的有效性。对于这4个模型,平均输入token长度约为13k。因此,使用1024个SnapKV,平均压缩率达到92%,使用4096个,达到68%,准确性几乎没有下降。我们比较了SnapKV和H2O在LongBench数据集上的性能,以进一步展示SnapKV的性能。为了公平评估准确性,我们将H2O的提示容量设置为4096。如表1所示,SnapKV的性能明显优于H2O。即使在提示KV缓存为1024的情况下,Mistral-7B-Instruct-v0.2上的SnapKV在16个基准测试中的11个上也比4096个缓存的H2O表现更好。
5.4 Command-R上的实验
为了进一步评估SnapKV的性能,我们使用Cohere的Command-R模型进行了实验,这是一个开源模型,拥有35B参数,能够处理长达128k token长度的序列。Command-R旨在处理需要长上下文的复杂任务,如检索增强生成(RAG)。我们在NarrativeQA和修改版的Needle-in-a-Haystack上对Command-R进行了广泛测试,并取得了有希望的结果。为了评估SnapKV对RAG的影响,我们在bioasq、HotpotQA和Cohere的内部基准测试上运行了测试,进一步证明了其有效性。在所有实验中,我们将KV缓存限制为最多4096个token,而池化核大小和窗口大小分别设置为13和64。对于我们的评估,这些超参数提供了2x到32x的KV缓存压缩比,具体取决于序列长度。
5.4.1 Needle-in-a-Haystack
在之前的实验中,注意到Needle-in-a-Haystack评估受到特定上下文使用的严重影响。为了解决这个问题,我们通过改变上下文组成来修改评估,每种长度和深度组合都进行了八次,产生了更可靠的结果。我们观察到,在这种设置下,所有测试模型的得分都比没有上下文洗牌的原始设置略有下降。为了简化,我们对基线模型和带有SnapKV的模型的得分进行了聚合。正如表2所示,即使在128k序列长度的情况下,将SnapKV应用于Command-R也没有性能下降,KV缓存的压缩比为32倍。
5.4.2 检索增强生成(RAG)
我们评估了SnapKV在RAG任务中的有效性,这些任务比合成长上下文任务(如Needle-in-a-Haystack)更复杂,比NarrativeQA等任务更接近实际用例。RAG任务需要根据给定的提示从索引语料库中选择相关文档。扩展的上下文窗口使检索更多文档成为可能,这可以提高模型性能。然而,这也增加了内存需求和延迟,突显了检索范围和系统资源之间的微妙平衡。SnapKV通过减少内存使用量同时提高性能,在这些任务中非常有益。我们使用从20,000到40,000个token的上下文长度数据集评估了SnapKV对RAG的影响。鉴于我们的KV缓存大小为4096,我们实现了5-10倍的压缩。如表3所示,SnapKV在保留了Command-R的98.8%性能的同时表现出色。
5.5 与并行解码的兼容性案例研究
本节提供了将KV缓存压缩与并行解码相结合的新视角。并行解码利用轻量级模型或适配器起草初始token,然后由更大的LLMs验证。这种策略有效地减少了内存开销,这是LLMs的一个关键问题,因为它们的自回归特性使它们比计算要求更高的模型更内存密集。具体来说,在LLMs中,每个解码步骤涉及生成一个token,HBM和缓存之间的权重传递导致了显著的开销。通过在生成期间保持与提示相关的KV缓存的恒定大小,SnapKV提高了生成效率。
图9显示了不同提示长度下的性能,Mistral-7B-Instruct-v0.24在达到128个生成步骤之前进行了最大处理,除非被预先停止。实验使用了QASPER的一个子集,提示固定为让LLM总结论文。采用的截断策略与LongBench标准一致,通过移除中间的上下文以达到所需的序列长度进行基准测试。
结果表明,随着序列长度的延长,Medusa的性能下降,这一挑战被SnapKV有效缓解,与Medusa相比,在10k长度的序列上实现了1.3倍的速度提升,与原生解码相比实现了2.2倍的速度提升。这一改进强调了将KV缓存压缩与并行解码框架相结合以提高LLM效率的潜力,特别是在长上下文场景中。
6 讨论
SnapKV是一个有效而简单的解决方案,它通过压缩KV缓存来减轻处理大量提示的计算和内存负担。我们观察到,在生成过程中,提示中的特定token从每个头部获得一致的注意力,我们的方法不仅检索了关键信息,还提高了处理效率。尽管SnapKV具有其优势,但其应用范围主要限于模型的生成方面,特别是针对生成阶段的KV缓存。这种局限性意味着SnapKV无法扩展模型的长上下文处理能力,如果模型本身在处理长上下文方面存在困难或表现不佳。此外,SnapKV的设计没有涵盖提示的推理处理,这限制了其在系统无法处理长提示长度时的有效性。尽管如此,我们的贡献为社区提供了重要的见解和工具,为管理大规模语言建模的挑战铺平了道路。附录提供了有关并行解码的更多实验和有关生成加速的讨论。