读完本文你将对大模型中常用的 KV 缓存机制提出的动机、原理、实现、Transformers 库中的 past_key_values 参数,以及它在张量并行下的应用有着清晰的认识,而这也是 LLaMA 系列模型当中将要用到的关键技术。
我们都知道,现在各个大模型在整体的网络结构上基本上并没有太大差异,都是基于 Transformer 中的 Docoder 改进而来,所以基本也就是对其中的模块或使用顺序进行了替换或优化,因此为了后面能够更好的理解 LLaMA 2 模型的原理,掌柜依旧会先逐一介绍其中各个小的模块的原理及实现,然后再来整体看 LLaMA 2 模型。

使用 KV 缓存自注意力计算过程
因此,在今天这篇文章中,掌柜想要和大家介绍的便是大模型当中都会使用到的 Key-Value Cache ,简称 KV Cache。
关键字:大模型、LLaMA、DeepSeek、KV缓存、多头注意力机制、KV Cache
1 动机
由于在生成式模型(如 Transformer 解码器)的推理过程中,token 是逐时刻解码生成得到,每个时刻新生成 token 时都需要将先前所有已生成的 tokens 重新输入到模型中,所以这就会带来巨大的计算冗余和重复,使得解码过程越来越慢。
例如在解码生成第 时刻时,会将第 到 时刻的 token 都输入到模型中进行编码;而在生成第 时刻时,又会将第 到 时刻的 token 都输入到模型中进行编码,其中第 到 时刻就会涉及到重复计算的问题。

图 1. 推理解码过程图
如图1所示,模型输入 prompt 的长度为 3,在对 时刻解码时会将整个 prompt 作为输入然后解码得到一个长度为 3 的张量,并取最后一个时刻作为 时刻的生成结果;在对 时刻解码时,会将 时刻的输入与输出拼接起来作为输入生成 时刻的结果;后续以此类推。
可以看出整个过程中会涉及到大量重复和冗余的计算。
注意这里的两个关键词“重复”冗余“”,重复代表着需要去重,而冗余则是要去掉。
有人可能会问,那训练阶段为什么不会呢?
这因为在训练阶段是一次性将整个样本输入到模型中(通过注意力掩码 attention mask 来使得编码当前时刻注意力只能关注到过去的信息而不能看到未来的信息),且所有时刻的输出都将用于损失计算,所以不存在重复或无用计算的问题。
因此,为了避免这种重复和冗余的计算的工作,Key-Value 缓存应运而生。简单来说,它在每次的自注意力计算过程中,都会将此时计算得到的 Key 和 Value 进行缓存,后续在生成下一个时刻的 token 时,模型只需对当前时刻输入的一个 token 计算隐藏状态,再把缓存与该隐藏状态拼接即可得到完整的隐藏状态 Key 和 Value,这样就能显著加快推理速度,而这在长序列或交互式应用中优势非常明显。
2 不采用 KV 缓存推理过程
在介绍完 KV 缓存出现的动机以后,我们再来看 KV 缓存的原理到底是什么样的。
不过为了让大家更好的理解整个过程,掌柜先来带大家通过图示的方式来回顾一下没有 KV 缓存时的完整计算过程,以便稍后将两者进行对比,以便更容易理解。
2.1 计算原理图解
下面,我们依旧以图1中的情境为例进行说明。
首先,我们需要明白的是当模型训练完成以后,在推理过程中各个权重参数是固定不变的,因此对于同样的输入部分,其输出结果也是不变的,这一前提我们要知道。
假定现在有一个训练完成的模型将用于推理任务,为了便于介绍我们只考虑涉及到 KV 缓存的部分,即自注意力机制的计算过程,如图2所示。

图 2. 不采用 KV 缓存自注意力计算过程,‘()’ 表示 Softmax 操作
在图2所示的示例中,原始 prompt 输入为一个长度为 3 的序列,输入到模型已经首先完成 3 个线性变换分别计算得到 、 和 ,进一步完成注意力权重的计算已经最终的输出,即最上方的结果。这里需要注意的是,因为这是在推理阶段,所以在对输入序列进行编码时需要进行 attention mask 操作,以保证当前时刻不能看到未来信息。在得到第 时刻的整个输出以后,模型将会取结果的最后一个时刻作为 时候的生成结果。
此时,我们可以得到第一个结论:在对第 个时刻解码时,生成内容中前 个时刻的结果是无用的。
进一步,开始依次进行后续时刻的解码生成,如图3所示。

图 3. 不采用 KV 缓存自注意力计算过程,‘()’ 表示 Softmax 操作
在图4中,模型对第 时刻解码时,会将整个 prompt 以及到当前时刻为止已经生成的内容拼接起来作为输入来生成当前时刻的结果。同理,此时的输入首先将完成 3 个线性变换分别计算得到 、 和 ,然后进行后续计算。此时我们可以发现,在第 时刻时,其输入序列的前 3 个 token (图中灰色部分)在第 时刻中已经分别进行过了一次同 、 和 的线性变换,也就是说此时 、 和 这3个矩阵的前三行都是之前已经计算过的(图中红框中的内容),这里是在重复计算。
进一步,得到 、 和 后将完成自注意力计算过程并得到最后的输出。这里需要注意的是,因为此时已经是逐时刻进行解码,模型可以看到当前时刻之前的所有信息,所以不再需要进行 attention mask 操作。同时,我们已经可以知道,对于 的输出结果,依旧只会取最后一个时刻对应的结果作为 时刻生成的内容,其余时刻的结果无用。
当第 时,其生成过程完全与上述过程一致,大家可以自行默想。
此时此刻掌柜相信大家已经发现了其中的门道:
① 如果我们每次在进行解码生成时,都将当前时刻计算得到的 和 进行缓存,那么在下一时刻进行解码时就不需要再重复计算对应的结果,这就是解决前面提到的“重复计算”的问题;
② 因为只取每个时刻输出结果的最后一个时刻作为当前时刻的生成结果,所以在①的基础上,当前时刻的输入仅使用上一时刻的输出的最后一个时刻即可,不需要和先前的输入拼接,这就是解决前面提到的“冗余计算”的问题;
③ 基于对于②的理解,所以我们不需要对 进行缓存,因为后面根本用不到。
上面这两点就是 KV 缓存的核心思想。下面我们先来通过一个示例来模拟一下不采用 KV 缓存推理时的整个过程。
2.2 从零开始实现
首先,我们来快速实现多头注意力机制的计算过程,代码如下所示:
1classAttention(nn.Module):
2def__init__(self, embed_dim=64, num_heads=8):
3 super().__init__()
4 self.num_heads = num_heads
5 self.embed_dim = embed_dim
6 self.head_dim = embed_dim // num_heads
7assert self.head_dim * num_heads == self.embed_dim
8 self.wq = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
9 self.wk = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
10 self.wv = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
11 self.wo = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
12
13defforward(self, x: torch.Tensor, mask=None):
14 bsz, seq_len, _ = x.shape
15 xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
16# [ bsz, seq_len, embed_dim] @ [embed_dim, embed_dim] = [ bsz, seq_len, embed_dim]
17 xq = xq.view(bsz, seq_len, self.num_heads, self.head_dim) # [bsz, seq_len, num_heads, head_dim]
18 keys = xk.view(bsz, seq_len, self.num_heads, self.head_dim) # [bsz, seq_len, num_heads, head_dim]
19 values = xv.view(bsz, seq_len, self.num_heads, self.head_dim) # [bsz, seq_len, num_heads, head_dim]
20 print(f"keys(values) shape [bsz, seq_len, num_heads, head_dim]: {values.shape}")
21 xq = xq.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim]
22 keys = keys.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim]
23 values = values.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim]
24 scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
25if mask isnotNone: # [1, 1, seq_len, seq_len]
26 scores = scores + mask
27 scores = F.softmax(scores.float(), dim=-1).type_as(xq)
28 output = torch.matmul(scores, values) # [bsz, num_heads, seq_len, head_dim]
29 output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
30return self.wo(output)
上述代码就是多头注意力的完整实现过程,掌柜已经对相关变量的形状进行了清晰地标注,相信大家对这部分内容已经非常熟悉,掌柜在这里就不再赘述。
2.3 推理过程模拟
进一步,我们来完模拟整个解码过程,示例代码如下:
1if __name__ == '__main__':
2 start_pos = seq_len = 3
3 bsz = 3
4 embed_dim, num_heads = 4, 2
5 total_len = 10
6 input_embeddings = torch.randn([bsz, seq_len, embed_dim])
7 attn = Attention(embed_dim=embed_dim, num_heads=num_heads, max_batch_size=10)
8for cur_pos in range(start_pos, total_len):
9 print(f" =========== decoding at pos: {cur_pos} ============")
10 print(f"input_embeddings shape [bsz, seq_len, embed_dim]: {input_embeddings.shape}")
11 _, seqlen, _ = input_embeddings.shape
12 mask = None
13if cur_pos == start_pos:
14 mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
15 mask = torch.triu(mask, diagonal=1)
16 print(mask)
17 print(f"mask shape [1, 1, seq_len, seq_len]: {mask.shape}]")
18 output = attn(input_embeddings, mask) # [bsz, seq_len, embed_dim]
19 print(f"attention output shape [bsz, seq_len, embed_dim]: {output.shape}")
20 next_token_hidden = output[:, -1].unsqueeze(1) # [bsz, 1, embed_dim]
21 print(f"next_token_hidden shape [bsz, 1, embed_dim] : {next_token_hidden.shape}")
22 input_embeddings = torch.cat([input_embeddings, next_token_hidden], dim=1)
在上述代码中,第2行定义初始序列的长度,也就是 prompt 的长度, 其中 start_pos 表示从 时刻开始解码生成。第 3~5 行是定义模型的相关参数。第6行是随机生成一个输入序列经过 embedding 后的结果,将作为初始输入。第7行是实例化一个多头注意力模块,这里我们将其简单的看成是解码器。第8行是开始逐时刻进行解码。第11~17行是构建对 prompt 编码时所需要的 attention mask 输入。第18、20行是得到当前时刻的解码输出,然后取最后一个时刻作为结果,并将其扩维到 [bsz, 1, embed_dim]。第22行则是将当前时刻的输入和结果拼接起来,作为下一个时刻的输入,并再次解码生成。
理解上述代码的时候,建议对照上面的图示过程。以上完整示例代码可参见 Code/C06_SelfAttention/C01_attention_no_kv_cache.py 文件。
2.4 输出结果分析
在上述代码执行结束以后,将会得到类似如下结果:
1 =========== decoding at pos: 3 ============
2 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 3, 4])
3 tensor([[[[0., -inf, -inf],
4 [0., 0., -inf],
5 [0., 0., 0.]]]])
6 mask shape [1, 1, seq_len, seq_len]: torch.Size([1, 1, 3, 3])]
7 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 3, 2, 2])
8 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 3, 4])
9 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
10
11 =========== decoding at pos: 4 ============
12 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 4, 4])
13 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 4, 2, 2])
14 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 4, 4])
15 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
16
17 =========== decoding at pos: 5 ============
18 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 5, 4])
19 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 5, 2, 2])
20 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 5, 4])
21 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
22
23 =========== decoding at pos: 6 ============
24 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 6, 4])
25 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 6, 2, 2])
26 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 6, 4])
27 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
28
29 =========== decoding at pos: 7 ============
30 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 7, 4])
31 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 7, 2, 2])
32 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 7, 4])
33 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
34
35 =========== decoding at pos: 8 ============
36 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 8, 4])
37 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 8, 2, 2])
38 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 8, 4])
39 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
40
41 =========== decoding at pos: 9 ============
42 input_embeddings shape [bsz, seq_len, embed_dim]: torch.Size([3, 9, 4])
43 keys(values) shape [bsz, seq_len, num_heads, head_dim]: torch.Size([3, 9, 2, 2])
44 attention output shape [bsz, seq_len, embed_dim]: torch.Size([3, 9, 4])
45 next_token_hidden shape [bsz, 1, embed_dim] : torch.Size([3, 1, 4])
对于上述结果,大家可以自行进行分析,关键一点就是观察各个变量的输出形状。
看到这里,掌柜想要给大家抛出一个疑问,在真正的推理场景中不同用户同时向模型发起请求,那么同一个时间段内模型接收到的多个样本长度肯定是不一样的,那么此时将其作为一个 batch 输入到模型中进行解码生成应该如何处理呢?
换句话说,在一个 batch 中,prompt 的初始长度不同,如何根据 prompt 同时生成多个样本的输出内容?
大家可以先想一想,在介绍 LLaMA 2 完整实现过程时,掌柜再来交这部分作业。
普通人如何抓住AI大模型的风口?
领取方式在文末
为什么要学习大模型?
目前AI大模型的技术岗位与能力培养随着人工智能技术的迅速发展和应用 , 大模型作为其中的重要组成部分 , 正逐渐成为推动人工智能发展的重要引擎 。大模型以其强大的数据处理和模式识别能力, 广泛应用于自然语言处理 、计算机视觉 、 智能推荐等领域 ,为各行各业带来了革命性的改变和机遇 。
目前,开源人工智能大模型已应用于医疗、政务、法律、汽车、娱乐、金融、互联网、教育、制造业、企业服务等多个场景,其中,应用于金融、企业服务、制造业和法律领域的大模型在本次调研中占比超过 30%。

随着AI大模型技术的迅速发展,相关岗位的需求也日益增加。大模型产业链催生了一批高薪新职业:

人工智能大潮已来,不加入就可能被淘汰。如果你是技术人,尤其是互联网从业者,现在就开始学习AI大模型技术,真的是给你的人生一个重要建议!
最后
只要你真心想学习AI大模型技术,这份精心整理的学习资料我愿意无偿分享给你,但是想学技术去乱搞的人别来找我!
在当前这个人工智能高速发展的时代,AI大模型正在深刻改变各行各业。我国对高水平AI人才的需求也日益增长,真正懂技术、能落地的人才依旧紧缺。我也希望通过这份资料,能够帮助更多有志于AI领域的朋友入门并深入学习。
真诚无偿分享!!!
vx扫描下方二维码即可
加上后会一个个给大家发

大模型全套学习资料展示
自我们与MoPaaS魔泊云合作以来,我们不断打磨课程体系与技术内容,在细节上精益求精,同时在技术层面也新增了许多前沿且实用的内容,力求为大家带来更系统、更实战、更落地的大模型学习体验。

希望这份系统、实用的大模型学习路径,能够帮助你从零入门,进阶到实战,真正掌握AI时代的核心技能!
01 教学内容

-
从零到精通完整闭环:【基础理论 →RAG开发 → Agent设计 → 模型微调与私有化部署调→热门技术】5大模块,内容比传统教材更贴近企业实战!
-
大量真实项目案例: 带你亲自上手搞数据清洗、模型调优这些硬核操作,把课本知识变成真本事!
02适学人群
应届毕业生: 无工作经验但想要系统学习AI大模型技术,期待通过实战项目掌握核心技术。
零基础转型: 非技术背景但关注AI应用场景,计划通过低代码工具实现“AI+行业”跨界。
业务赋能突破瓶颈: 传统开发者(Java/前端等)学习Transformer架构与LangChain框架,向AI全栈工程师转型。

vx扫描下方二维码即可

本教程比较珍贵,仅限大家自行学习,不要传播!更严禁商用!
03 入门到进阶学习路线图
大模型学习路线图,整体分为5个大的阶段:

04 视频和书籍PDF合集

从0到掌握主流大模型技术视频教程(涵盖模型训练、微调、RAG、LangChain、Agent开发等实战方向)

新手必备的大模型学习PDF书单来了!全是硬核知识,帮你少走弯路(不吹牛,真有用)

05 行业报告+白皮书合集
收集70+报告与白皮书,了解行业最新动态!

06 90+份面试题/经验
AI大模型岗位面试经验总结(谁学技术不是为了赚$呢,找个好的岗位很重要)

07 deepseek部署包+技巧大全

由于篇幅有限
只展示部分资料
并且还在持续更新中…
真诚无偿分享!!!
vx扫描下方二维码即可
加上后会一个个给大家发

845

被折叠的 条评论
为什么被折叠?



