从MHA、GQA、MQA到MHLA,注意力的几种处理方式

注意力机制定义

传统的注意力机制QK不同源。

而在自注意力机制中,QKV都来自于同一组元素,是想让机器注意到整个输入中不同部分之间的相关性。

计算方式:
Attention(Q,K,V)=Softmax(Q⋅KTdk)⋅V Attention(Q,K,V)=Softmax(\frac{Q·K^T}{\sqrt{d_k}})·V Attention(Q,K,V)=

### 分组查询注意力GQA)定义解释 #### 什么是分组查询注意力? 分组查询注意力(Grouped Query Attention, GQA)是对传统多头注意力(Multi-Head Attention, MHA)和多查询注意力(Multi-Query Attention, MQA)的一种扩展形式[^1]。这种机制旨在通过提供计算效率与模型表达能力之间的灵活权衡,从而优化大规模语言模型中的注意力机制。 #### 关键特性 GQA的核心理念在于将查询头划分为若干个小组,在每个小组内部共享一组公共的键(Key, K)和值(Value, V)投影矩阵。具体来说: - 查询头被均匀分配到 \( G \) 组中; - 每个组内的所有查询头共用同一套键和值权重参数; - 这种设计既保留了一定程度上的并行处理优势,又避免了完全依赖单一键值对所带来的潜在质量问题[^2]。 #### 工作原理 当应用GQA时,输入序列会经过线性变换得到不同维度下的表示向量作为各个头部的工作基础。对于每一个查询组而言,其对应的键和值是从整个输入空间映射而来而非仅限于该组独有的部分;这意味着尽管存在资源共享的情况,但依然能捕捉到来自全局的信息交互模式[^3]。 ```python import torch.nn as nn class GroupedQueryAttention(nn.Module): def __init__(self, d_model, num_query_groups, n_heads_per_group=8): super().__init__() self.num_query_groups = num_query_groups total_heads = num_query_groups * n_heads_per_group # Define key and value projections shared within each group self.kv_projections = nn.Linear(d_model, d_model * 2) # Define query projection for all heads self.q_projection = nn.Linear(d_model, d_model * total_heads) def forward(self, x): batch_size, seq_len, _ = x.size() kvs = self.kv_projections(x).reshape(batch_size, seq_len, 2, -1) queries = self.q_projection(x).view(batch_size, seq_len, self.num_query_groups, -1) keys, values = kvs.chunk(2, dim=-2) return keys, values, queries ``` 此代码展示了如何构建一个简单的基于PyTorch框架的GQA层结构实例。这里`num_query_groups`指定了要创建多少个独立的查询组,而`n_heads_per_group`则决定了每组内含有的查询头数目[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值