MHA(Multi-Head Attention) 与GQA(Grouped Query Attention)的区别

引言

       Grouped Query Attention(GQA,分组查询注意力)和多头注意力机制(Multi-Head Attention,MHA)都是Transformer模型中用于捕获输入序列中不同位置之间关系的注意力机制。然而,它们在实现方式和计算复杂度上有所不同。下面我将详细介绍它们的原理以及它们之间的区别。

1. 多头注意力机制(MHA)

1.1 概念

       多头注意力机制(Multi-Head Attention,MHA)是Transformer模型中的核心组件。它通过并行的多组注意力机制,让模型能够在不同的子空间中关注序列的不同方面,从而加强模型的表达能力。

1.2 工作原理

       输入表示:给定输入序列 X ∈ R n × d model \mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}} XRn×dmodel n n n是序列长度, d model d_{\text{model}} dmodel是隐藏维度。

       线性投影:通过线性变换将输入 X \mathbf{X} X映射为查询( Q \mathbf{Q} Q)、键( K \mathbf{K} K)和值( V \mathbf{V} V):

Q = X W Q , K = X W K , V = X W V \mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V Q=XWQ,K=XWK,V

### MHAGQA MLA 的区别及应用场合 #### 多头注意力机制(Multi-Head Attention, MHA) 多头注意力机制允许模型在不同的表示子空间中并行关注不同位置的信息。每个头独立操作,最终结果通过拼接各头的结果来获得更丰富的特征表达[^1]。 ```python import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads # 定义线性变换层 self.query_linear = nn.Linear(embed_size, embed_size) self.key_linear = nn.Linear(embed_size, embed_size) self.value_linear = nn.Linear(embed_size, embed_size) def forward(self, query, key, value): batch_size = query.size(0) # 对输入进行线性变换 Q = self.query_linear(query) K = self.key_linear(key) V = self.value_linear(value) # 将嵌入维度分割成多个头 Q = Q.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) K = K.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) V = V.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) # 计算注意力分数并加权求 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5) attention_weights = F.softmax(scores, dim=-1) output = torch.matmul(attention_weights, V) # 合并头部并将结果传递给下一个线性层 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size) return output ``` #### 组化查询注意力机制(Grouped Query Attention, GQA) 为了减少计算量,GQA引入了查询分组的概念,即某些查询可以共享相同的键值矩阵。这减少了重复计算的数量,在大规模数据集上尤其有效[^3]。 ```python class GroupedQueryAttention(nn.Module): def __init__(self, embed_size, num_groups, heads_per_group): super(GroupedQueryAttention, self).__init__() self.embed_size = embed_size self.num_groups = num_groups self.heads_per_group = heads_per_group # 初始化参数... def forward(self, queries, keys, values): # 实现GQA的具体逻辑... pass ``` #### 压缩键值注意力机制(Compressed Key/Value Attention, MLA) MLA进一步优化了资源利用效率,通过对键值向量实施低秩近似压缩处理,从而显著降低了存储开销以及前向传播过程中的运算复杂度。 $$K_{\text{compressed}} = U_K \cdot S_K \cdot V_K^T$$ $$V_{\text{compressed}} = U_V \cdot S_V \cdot V_V^T$$ ```python from scipy.linalg import svd def compress_matrix(matrix, rank): u, s, vh = svd(matrix) compressed = np.dot(u[:, :rank], np.dot(np.diag(s[:rank]), vh[:rank, :])) return compressed class CompressedKeyValueAttention(nn.Module): def __init__(self, embed_size, compression_rank): super(CompressedKeyValueAttention, self).__init__() self.compress_key = lambda k: compress_matrix(k, compression_rank) self.compress_value = lambda v: compress_matrix(v, compression_rank) # 其他初始化... def forward(self, queries, keys, values): compressed_keys = self.compress_key(keys) compressed_values = self.compress_value(values) # 使用压缩后的keysvalues继续执行标准的注意力机制... pass ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值