Transformer升级之路:多头潜在注意力机制(MLA)究竟好在哪里?看完这一篇就懂了!!

前言

自从 DeepSeek 爆火后,它所提的 Attention 变体 MLA(Multi-head Latent Attention)也愈发受到关注。

MLA 通过巧妙的设计实现了 MHA 与 MQA 的自由切换,使得模型可以根据训练和推理的不同特性(Compute-Bound or Memory-Bound)选择最佳的形式,尽可能地达到效率最大化。

诚然,MLA 很有效,但也有观点认为它不够优雅,所以寻找 MLA 替代品的努力一直存在,包括我们也有在尝试。

然而,经过一段时间的实验,我们发现很多 KV Cache 相同甚至更大的 Attention 变体,最终效果都不如 MLA。这不得不让我们开始反思:MLA 的出色表现背后的关键原因究竟是什么?

接下来,本文将详细介绍笔者围绕这一问题的思考过程以及相关实验结果。

1、观察

MLA 的主要特点如下:

\1. MLA 在训练阶段是一个 qk_head_dims=(128+64)、v_head_dims=128 的 MHA;

\2. MLA 在解码阶段是一个 qk_head_dims=(512+64)、v_head_dims=512、KV-Shared 的 MQA;

\3. MLA 的 [qc, qr]、[kc, kr] 拼接,可以理解为一种Partial RoPE。

2、猜测

MHA、GQA 常用的 head_dims 是 128,而对于 MLA 来说,不管是从训练看的 128+64,还是从推理看的 512+64,都要大于 128,再结合《突破瓶颈,打造更强大的 Transformer》[2] 的经验,我们有:

猜测 1:增大 head_dims 是 MLA 好的关键之一。

另外,KV-Shared 这个特性,可以在同等 KV Cache 大小下,增大 GQA 的 head_dims 或者 num_groups,所以有:

猜测 2:KV-Shared 是 MLA 好的关键之一。

最后,此前有一些理论和实验显示 Partial RoPE 可能会对效果有正面帮助,所以有:

猜测 3:Partial RoPE 是 MLA 好的关键之一。

3、实验

现在我们通过实验逐一检验以上猜测。

设置

所有实验公共部分的超参数如下:

  1. 类似 LLAMA3 的 Dense 模型;

  2. hidden_size=2048,num_layers=12,num_heads=16;

  3. 优化器是Muon,Attention 部分 per head 更新;

  4. 训练长度为 4096,总 tokens 数为 16B,总训练步数为 16k;

  5. 所有实验都是只改变 Attention,所以参数量不会严格对齐。

Part I

MLA 的 KV Cache 大小是 512+64,约等于 GQA2-128(第一个数字是 num_groups,第二个数字是 head_dims),所以对比的 baseline 为 GQA2-128GQA1-256

为了验证 Partial RoPE,我们增加了 GQA1-256-PR,具体做法是将 Q、K 的 256 dims 分成 192+64 两部分,在 64 上加 RoPE,192 不加。

结果如下:

图片

即:

图片

初步验证了增大 head_dims 和 Partial RoPE 的作用。这样看来,MLA 的设计中,RoPE 和 NoPE 拼接这部分看似无奈的设计,极有可能是它效果优异的关键原因!原论文声称 MLA 甚至优于 MHA,大概率也是因为所对比的 MHA 的 head_dims 只有 128。

Part II

为了进一步验证增大 head_dims 的作用,我们另外跑了 MHA、GQA2-192、MLA-256 三个实验,MHA 是 head_dims=128 的常规 MHA,GQA2-192 是直接增大 GQA2 的 head_dims 到 192,MLA-256 是将 MLA 的 128+64 提升到 192+64,对照如下:

图片

可以看到,MHA 总参数量更多,KV Cache 更是 7 倍于 MLA,但 Loss 才堪堪追平 MLA,这跟 DeepSeek-V2 里边的结论接近。

此外,GQA2-192 优于 GQA2-128,但不如 GQA1-256;MLA 的 head_dims 升到 (192+64) 后,相比 (128+64) 也还能进一步提升效果。这些现象都表明,增加 head_dims 远比增加 num_groups 更有效。

Part III

接下来我们验证 KV-Shared,即 K、V 共享全部或大部分 dims。这里我们主要考虑的替代品是 head_dims 不超过 256 的 GQA,并且控制 KV Cache 的总大小跟 MLA 接近,所以当 KV-Shared 时,我们可以至多可以考虑 GQA2-256。

由于 KV-Shared 跟 RoPE 不完全兼容,参考 MLA 的做法,我们将 256 分成 192+64 两部分,其中:

  1. 192 部分不加 RoPE,在 K、V 间共享;

  2. 64 部分加 RoPE,只用于 K;

  3. V 另外再投影 64 dims,concat 到共享的 192 dims 上去。

这样一来,K、V 的 head_dims 都是 256,KV Cache 总大小是 (192+64+64)2=640,略大于 MLA 的 512+64=576,这个版本我们简记为*“GQA2-(192+64)-S1”**,其实“S1”是“Shared-1”的缩写。

Part IV

另外一种 KV-Shared 的方案是:

  1. 192 部分不加 RoPE,在 K、V 间共享;

  2. 64 部分加 RoPE,同样在 K、V 间共享;

  3. 做 Attention,由于 V 带 RoPE,此时是绝对位置编码效果;

  4. 为了保证相对位置编码,将输出分成 192+64 两部分,64 部分再加一次逆向 RoPE。

这种做法是 K、V 完全共享,KV Cache 大小是 (192+64)*2=512,略小于 MLA。这个版本我们称为 “GQA2-(192+64)-S2”,“S2”是“Shared-2”的缩写,背后的原理是笔者新提出的 VO-RoPE,参考Transformer升级之路:第二类旋转位置编码 [3]。

Part V

另外,根据同样思路补了几个 GQA4 和 GQA1 的实验。所有实验结果汇总如下:

图片

这里 “GQA1-(512+64)-S3” 是按照 MLA 的推理形式实现的 MQA,形式介乎 S1 与 S2 之间,它的主要特点是 head_dims 大。

结果解读:

  1. KV-Shared 的 GQA 自带 Partial RoPE;

  2. KV-Shared 的 GQA2-256,也能超过 MLA;

  3. VO-RoPE 的引入,似乎有利于效果(S1 ≲ S2);

  4. 同等 KV Cache 下,head_dims 越大越好;

  5. GQA2-(192+64)-S2 略微超过 GQA1-256-PR;

\6. GQA4-(128+64)-S2 的 KV Cache 最大,但效果不是最优,再次表明 head_dims 更关键。

关于 KV-Shared,还有两点观察:

  1. 训练过程中,GQA1-256-PR 前期是明显领先 GQA2-(192+64)-S2,但后期被追平甚至略微反先,猜测 GQA1-256-PR 可能有后劲不足的嫌疑;

  2. 如果没有 KV-Shared,GQA 顶多是 GQA1-256,也就是说 head_dims 顶天了 256,但有 KV-Shared 的话,GQA 可以做到 GQA1-512-S,单纯从 head_dims 看,KV-Shared 天花板更高。

Part VI

由于没有严格对齐参数量,可能读者会有“到底是增加参数量还是增加 head_dims 更本质”的疑虑,所以这里补充几个对齐参数量的实验。

这里考虑的对齐参数量的方式有三种:

1. double-heads: 以“GQA2-128 vs GQA1-256”为例,将 GQA2-128 的 num_heads 翻倍,可以让 GQA2-128 的参数量跟 GQA1-256 相同;

2. 缩减 MLP: 缩小 MLP(SwiGLU)的 intermediate_size,也可以使得 GQA1-256 的参数量跟 GQA2-128 大致相同;

3. Q&O LoRA: GQA 的主要参数量来自 Query 和 Output 的投影矩阵,对这两个矩阵改用 LoRA,也可以降低 GQA1-256 的参数量。

实验结果如下:

图片

结果主要分三块:

  1. heads 翻倍相比 head_dims 翻倍,loss 稳定差 0.003 左右;

  2. 缩小 MLP 比 head_dims 减半,loss 稳定优 0.004 左右;

  3. Q&O LoRA 性能损失最小,可以实现 head_dims 翻倍但参数量不增,且 loss 明显降。

结论: 如果从增加参数量角度看,增大 head_dims 可能是效果增益较大的方向,配合 Q&O LoRA 可以实现参数量几乎不增,但收益仍相当。

4、小结

初步结论是:

  1. 增大 head_dims 收益最大;

  2. Partial RoPE 对 Loss 也有一定帮助;

  3. KV-Shared 应该也有一定作用。

这样看来,此前我们一直在 head_dims=128 下找 MLA 的替代品,感觉是起点就先天不足了,难怪一直比不上 MLA。要想追平 MLA,head_dims 应该要 192 起步了,并辅以 Partial RoPE。至于 KV-Shared,也可能有用,但应该还需要更大规模的验证。

5、意义

其实这里边的意义,就看我们换掉 MLA 的决心有多强。

假设 GQA2-(192+64)-S2 可以替代 MLA,但 MLA 也可以升到 256,目前看来 GQA2-(192+64)-S2 比不上 MLA-256 。那么换掉 MLA 的唯二好处是:

  1. 结构更简单,可以方便加 QK-Norm;

  2. 解码阶段的 head_dims 由 512+64 变成了 256,同时 num_groups 变为 2,可以 TP。

最后的最后

感谢你们的阅读和喜欢,作为一位在一线互联网行业奋斗多年的老兵,我深知在这个瞬息万变的技术领域中,持续学习和进步的重要性。

为了帮助更多热爱技术、渴望成长的朋友,我特别整理了一份涵盖大模型领域的宝贵资料集。

这些资料不仅是我多年积累的心血结晶,也是我在行业一线实战经验的总结。

这些学习资料不仅深入浅出,而且非常实用,让大家系统而高效地掌握AI大模型的各个知识点。如果你愿意花时间沉下心来学习,相信它们一定能为你提供实质性的帮助。

这份完整版的大模型 AI 学习资料已经上传优快云,朋友们如果需要可以微信扫描下方优快云官方认证二维码免费领取【保证100%免费

大模型知识脑图

为了成为更好的 AI大模型 开发者,这里为大家提供了总的路线图。它的用处就在于,你可以按照上面的知识点去找对应的学习资源,保证自己学得较为全面。
在这里插入图片描述

经典书籍阅读

阅读AI大模型经典书籍可以帮助读者提高技术水平,开拓视野,掌握核心技术,提高解决问题的能力,同时也可以借鉴他人的经验。对于想要深入学习AI大模型开发的读者来说,阅读经典书籍是非常有必要的。

在这里插入图片描述

实战案例

光学理论是没用的,要学会跟着一起敲,要动手实操,才能将自己的所学运用到实际当中去,这时候可以搞点实战案例来学习。

在这里插入图片描述

面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下

在这里插入图片描述

640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传优快云,朋友们如果需要可以微信扫描下方优快云官方认证二维码免费领取【保证100%免费

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值