论文下载地址:https://arxiv.org/pdf/2305.13245v3.pdf
代码下载地址:https://github.com/fkodom/grouped-query-attention-pytorch
多查询注意力(Multi-Query Attention)详解
目录
前面一篇文章我们已经讲过了关于多头注意力MHA(Multi-Head Attenti)以及多查询注意力机制,多查询注意力是基于多头注意力进行改进的,多查询注意力采用key和value共享的方式,从而在每一次加载时,减少内存的访问以及计算。虽然多查询注意力减少了内存的访问以及计算量。但是在最终结果比多头注意力要差,而分组查询注意力GQA(Group-query Attention)兼顾了效率和性能,也就是将key和value以及对应的query进行分组计算注意力,在每一个组中采用多查询注意力计算法方式。
提出目的和方法
提出目的
多头注意力机制(MQA),它只使用单一的键值对,极大地加快了解码器的推理速度。然而,MQA 可能会导致质量下降,而且可能不适合仅为更快的推理而训练单独的模型。
提出方法
提出了一种将多头语言模型的检查点进行预训练的配方,结合 5% 原始预训练计算量,并引入分组查询注意力(GQA),这是多头注意力的一般化,使用一个中间(多于一个,少于查询头总数)数量的键值头。结果显示,经过预训练的 GQA 在保持与多头注意力相当的速度的同时,实现了接近多头注意力的性能。
阐述已有方法存在问题以及改进
自回归解码器推理是Transformer模型的主要性能瓶颈,这源于每个解码步骤都需要加载解码器权重及所有注意力键值对所带来的内存带宽开销。通过采用多查询注意力机制(Multi-Query Attention, MQA)——即使用多查询头单键值头的设计——可显著降低键值对加载的内存带宽需求。
然而,多查询注意力可能导致模型质量下降和训练不稳定性,且难以同时训练兼顾高质量和高效推理的独立模型。尽管部分语言模型(如PaLM)已采用多查询注意力,但包括T5和LLaMA在内的许多公开模型仍保持标准多头注意力架构。
本研究为加速大语言模型推理提出两项贡献:
首先,本文证明只需投入少量额外训练计算量,即可将多头注意力(MHA)检查点升级改造(uptrain)为多查询注意力版本,这种经济高效的方法能同时保留原始MHA检查点的高质量特性。
其次,本文提出分组查询注意力(Grouped-Query Attention, GQA),这种介于多头与多查询注意力之间的混合架构,通过为每组查询头分配共享键值头,在保持接近多头注意力质量的同时,实现与多查询注意力相当的推理速度。
组查询注意力
综合实验
局限性
本文重点针对键值对加载过程中的内存带宽开销进行优化,该问题在生成长序列时尤为突出。但由于长文本生成质量本身难以评估(如摘要任务采用的ROUGE指标存在固有缺陷),本文无法完全确信当前的效率-质量权衡是最优解。实验存在以下局限性: