一、LLaMA的核心改进全景
Meta开源的LLaMA模型凭借其卓越的性能表现成为大模型发展的重要里程碑。相较于标准Transformer架构,LLaMA主要在以下几个方面进行了关键改进:
- 位置编码升级:采用旋转位置编码(Rotary Position Embedding, RoPE)
- 归一化革新:对每个 Transformer 子层的输入进行归一化(Pre-normalization),并使用RMS-Norm替代传统LayerNorm。
- 激活函数优化:引入 SwiGLU 激活函数取代 ReLU 非线性函数。
- 注意力优化(LLaMA 2):引入分组查询注意力(Grouped Query Attention, GQA)
这些改进显著提升了模型的计算效率和长文本处理能力,今天我们来学习分组查询注意力(Grouped Query Attention, GQA)。
其余部件的学习链接持续更新中,欢迎关注:
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之旋转编码RoPE(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之均方根误差标准化RMSNorm(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之SwiGLU激活函数(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之分组查询注意力(Grouped Query Attention)(含代码实现)
二、分组查询注意力(Grouped Query Attention)
2.1 改进动机
传统Transformer使用多头注意力(Multi-Head Attention, MHA),每个头独立生成查询(Query)、键(Key)和值(Value)。虽然MHA能捕捉丰富的上下文信息,但存在以下问题:
- 计算冗余:每个头独立计算Q/K/V,参数量和内存占用高。
- 推理延迟:生成任务中逐token解码时,KV缓存占用内存过大。
**多查询注意力(Multi-Query Attention, MQA)**通过共享所有头的K和V矩阵降低计算量,但牺牲了模型表达能力。
GQA在MHA和MQA之间找到了平衡:将查询头分组,组内共享键和值,既减少计算开销,又保留多粒度语义捕捉能力。
示意图解析:
- Multi-Head Attention(左):每个头独立生成Q/K/V,参数量最大。
- Grouped Query Attention(中):将查询头分为若干组,组内共享K和V,参数量显著降低。
- Multi-Query Attention(右):所有查询头共享同一组K和V,参数量最小但表达能力受限。
2.2 数学原理
给定输入序列 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d,GQA的计算步骤如下:
- 分组查询投影:将
h
h
h 个查询头分为
g
g
g 组,每组包含
h
/
g
h/g
h/g 个头:
Q i = X W i Q , K j = X W j K , V j = X W j V ( i = 1 , … , h ; j = 1 , … , g ) Q_i = X W_i^Q, \quad K_j = X W_j^K, \quad V_j = X W_j^V \quad (i=1,\dots,h; \ j=1,\dots,g) Qi=XWiQ,Kj=XWjK,Vj=XWjV(i=1,…,h; j=1,…,g) - 注意力计算:每组查询与对应的共享键值交互:
Attention ( Q i , K j , V j ) = softmax ( Q i K j T d k ) V j \text{Attention}(Q_i, K_j, V_j) = \text{softmax}\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right) V_j Attention(Qi,Kj,Vj)=softmax(dkQiKjT)Vj - 输出拼接:将各组输出拼接后线性变换:
GQA ( X ) = Concat ( head 1 , … , head h ) W O \text{GQA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O GQA(X)=Concat(head1,…,headh)WO
其中, d k d_k dk 为键的维度, W O W^O WO 为输出投影矩阵。
2.3 源码实现
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
def __init__(self, hidden_dim=768, head_num=4, group_num=2, dropout=0.1):
super().__init__()
assert hidden_dim % head_num == 0
assert head_num % group_num == 0
self.hidden_dim = hidden_dim
self.head_num = head_num
self.group_num = group_num
self.head_dim = hidden_dim // head_num
self.group_head_num = head_num // group_num
self.query = nn.Linear(hidden_dim, hidden_dim)
self.key = nn.Linear(hidden_dim, group_num * self.head_dim)
self.value = nn.Linear(hidden_dim, group_num * self.head_dim)
self.output_proj = nn.Linear(hidden_dim, hidden_dim)
self.attention_dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, attention_mask=None):
# x: [batch_size,seq_len,hidden_dim]
B, S, H = x.shape
q = self.query(x)
k = self.key(x)
v = self.value(x)
q = q.view(B, S, self.head_num, self.head_dim).transpose(1, 2)
k = k.view(B, S, self.group_num, self.head_dim).transpose(1, 2) # k: [batch_size,group_num,seq_len,seq_len]
v = v.view(B, S, self.group_num, self.head_dim).transpose(1, 2)
k = k.repeat(1, self.group_head_num, 1, 1) # k: [batch_size,head_num,seq_len,seq_len]
v = v.repeat(1, self.group_head_num, 1, 1)
attention_score = q @ k.transpose(-1, -2) / H ** 0.5 # attention_score: [batch_size,head_num,seq_len,seq_len]
if attention_mask is not None:
# attention_mask: [batch_size,seq_len,seq_len] -> [batch_size,head_num,seq_len,seq_len]
attention_mask = attention_mask.unsqueeze(1).repeat(1, self.head_num, 1, 1)
attention_score = attention_score.masked_fill(attention_mask == 0, float('-inf'))
attention_score = self.softmax(attention_score)
attention_score = self.attention_dropout(attention_score)
out = attention_score @ v # out: [batch_size,head_num,seq_len,head_dim]
out = out.transpose(1, 2).contiguous().view(B, S, H)
out = self.output_proj(out)
return out, attention_score
if __name__ == "__main__":
hidden_dim = 8
batch_size = 2
seq_len = 3
print_result = True
is_mask = True
# 初始化模型实例
model = GroupedQueryAttention(hidden_dim=hidden_dim)
# 生成随机输入
x = torch.randn(batch_size, seq_len, hidden_dim)
print(f"x的形状: {x.shape}")
# 前向传播
if is_mask:
mask = torch.tril(torch.ones(seq_len, seq_len))
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
print(f"mask的形状: {mask.shape}")
print(f"mask: {mask}")
out, attention_score = model(x, mask)
else:
mask = None
out, attention_score = model(x)
# 检查输出形状
assert out.shape == (batch_size, seq_len, hidden_dim), f"输出形状错误,期望 {batch_size, seq_len, hidden_dim},得到 {out.shape}"
assert attention_score.shape == (batch_size, model.head_num, seq_len, seq_len), f"注意力分数形状错误,期望 {batch_size, model.head_num, seq_len, seq_len},得到 {attention_score.shape}"
print(f"{GroupedQueryAttention.__name__} 测试通过!")
print(f"输出形状: {out.shape}")
print(f"注意力分数形状: {attention_score.shape}")
if print_result:
print("输出:")
print(out)
print("注意力分数:")
print(attention_score)