分组查询注意力GQA(Grouped-query attention)算法详解

论文下载地址:https://arxiv.org/pdf/2305.13245v3.pdf

代码下载地址:https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py

代码下载地址:https://github.com/fkodom/grouped-query-attention-pytorch

多查询注意力(Multi-Query Attention)详解

目录

提出目的和方法

提出目的

提出方法

阐述已有方法存在问题以及改进

组查询注意力

综合实验

局限性


        前面一篇文章我们已经讲过了关于多头注意力MHA(Multi-Head Attenti)以及多查询注意力机制,多查询注意力是基于多头注意力进行改进的,多查询注意力采用key和value共享的方式,从而在每一次加载时,减少内存的访问以及计算。虽然多查询注意力减少了内存的访问以及计算量。但是在最终结果比多头注意力要差,而分组查询注意力GQA(Group-query Attention)兼顾了效率和性能,也就是将key和value以及对应的query进行分组计算注意力,在每一个组中采用多查询注意力计算法方式。

提出目的和方法

提出目的

        多头注意力机制(MQA),它只使用单一的键值对,极大地加快了解码器的推理速度。然而,MQA 可能会导致质量下降,而且可能不适合仅为更快的推理而训练单独的模型

提出方法

        提出了一种将多头语言模型的检查点进行预训练的配方,结合 5% 原始预训练计算量,并引入分组查询注意力(GQA),这是多头注意力的一般化,使用一个中间(多于一个,少于查询头总数)数量的键值头。结果显示,经过预训练的 GQA 在保持与多头注意力相当的速度的同时,实现了接近多头注意力的性能

阐述已有方法存在问题以及改进

        自回归解码器推理是Transformer模型的主要性能瓶颈,这源于每个解码步骤都需要加载解码器权重及所有注意力键值对所带来的内存带宽开销。通过采用多查询注意力机制(Multi-Query Attention, MQA——即使用多查询头单键值头的设计——可显著降低键值对加载的内存带宽需求。

        然而,多查询注意力可能导致模型质量下降和训练不稳定性,且难以同时训练兼顾高质量和高效推理的独立模型。尽管部分语言模型(如PaLM)已采用多查询注意力,但包括T5LLaMA在内的许多公开模型仍保持标准多头注意力架构。

本研究为加速大语言模型推理提出两项贡献

        首先,本文证明只需投入少量额外训练计算量,即可将多头注意力(MHA)检查点升级改造(uptrain)为多查询注意力版本,这种经济高效的方法能同时保留原始MHA检查点的高质量特性

        其次,本文提出分组查询注意力(Grouped-Query Attention, GQA),这种介于多头与多查询注意力之间的混合架构,通过为每组查询头分配共享键值头,在保持接近多头注意力质量的同时,实现与多查询注意力相当的推理速度。

组查询注意力

综合实验

局限性

本文重点针对键值对加载过程中的内存带宽开销进行优化,该问题在生成长序列时尤为突出。但由于长文本生成质量本身难以评估(如摘要任务采用的ROUGE指标存在固有缺陷),本文无法完全确信当前的效率-质量权衡是最优解。实验存在以下局限性:

1. ​训练成本限制​ ​:未将增量训练的 XXL-GQA 模型与从头训练的对比模型进行比较,因此无法量化两种训练方式的相对性能差异
2. ​架构覆盖范围​ ​:当前评估仅针对编码器 - 解码器架构(如 T5 ),而近期流行的纯解码器模型(如 LLaMA )因自注意力与交叉注意力机制合并, GQA 相对 MQA 的优势可能更为显著

### 分组查询注意力机制的概念与实现 分组查询注意力Grouped-Query Attention, GQA)是一种改进的多头注意力机制变体,其核心在于通过减少键值投影头的数量来平衡计算效率和性能。相比于传统的多头注意力(Multi-head Attention, MHA),GQA 使用较少数量的键值头(Key-Value Heads),通常介于 1 和查询头总数之间[^2]。 #### 关键特性 GQA 的设计目标是在保持较高推理速度的同时提升模型质量。具体而言: - **键值头数量折衷**:GQA 中使用的键值头数小于查询头数,这减少了内存占用并加速了计算过程。 - **性能优化**:经实验验证,GQA 在多项任务上表现出与标准 MHA 类似的性能水平,同时显著优于仅使用单一键值头的 Multi-Query Attention (MQA)。 #### 数学表示 假设输入序列长度为 \(N\),嵌入维度为 \(d_{model}\),则可以定义如下参数: - 查询头数为 \(h_q\) - 键值头数为 \(h_k\) (其中 \(1 \leq h_k < h_q\)) 对于每个位置 \(i\)查询向量 \(Q_i\)、键向量 \(K_j\) 和值向量 \(V_j\) 可分别由输入矩阵线性变换得到。随后,GQA查询分为若干并与对应的共享键值头交互: \[ A^{g} = softmax(\frac{Q_g K^T}{\sqrt{d_k}}) V \] 最终结果通过对各输出进行拼接获得: ```python def grouped_query_attention(Q, K, V, num_groups): """ Q: Query tensor of shape [batch_size, seq_len, heads_q * dim_per_head] K: Key tensor of shape [batch_size, seq_len, heads_kv * dim_per_head] V: Value tensor of shape [batch_size, seq_len, heads_kv * dim_per_head] num_groups: Number of groups for grouping query heads. """ batch_size, seq_len, _ = Q.shape # Reshape and split into groups Q_grouped = Q.view(batch_size, seq_len, num_groups, -1).transpose(2, 3) K_expanded = K.repeat_interleave(num_groups // K.size(-1), dim=-1) scores = torch.matmul(Q_grouped / math.sqrt(K.size(-1)), K_expanded.transpose(-2, -1)) attn_weights = F.softmax(scores, dim=-1) context = torch.matmul(attn_weights, V) return context.reshape_as(Q) ``` 上述代码展示了如何基于 PyTorch 实现分组查询注意力的核心逻辑。 #### 应用场景 GQA 已被广泛应用于现代大型语言模型中,例如 LLaMA2 和 Mistral。这些模型利用 GQA 提升了推理阶段的吞吐量,同时维持较高的精度。 --- ### 注意力机制的空间增强方法 除了 GQA,在卷积网络中也存在其他形式的空间注意力增强技术。例如,“Spatial Group-wise Enhance” 方法提出了一种空间分增强策略,用于改善语义特征学习能力[^3]。该方法通过引入局部区域内的通道间依赖关系,进一步增强了模型捕捉复杂模式的能力。 此外,全局注意力机制也被用来保留更多上下文信息以加强信道-空间相互作用[^4]。这种方法能够有效缓解传统 CNN 对长距离依赖建模不足的问题。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值