I. Concept
DeepSeek has recently drawn worldwide attention, and for its latest DeepSeek-V3 model the development team notes that, to achieve efficient inference and economical training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA). MLA is a novel attention mechanism designed to dramatically reduce inference-time memory usage and compute cost while preserving model quality. It was first proposed in the paper "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model". This article walks through the key ideas.
II. Core Principles
MLA improves on traditional multi-head attention (MHA) in the following ways:
1. Low-rank joint KV compression
Traditional MHA must cache the complete Key and Value tensors, which severely limits how far the batch size and sequence length can grow. The core idea of MLA is to jointly compress the attention keys and values via a low-rank factorization, shrinking the cache:
$$
\mathbf{c}_t^{KV} = W^{DKV}\mathbf{h}_t,\qquad
\mathbf{k}_t^{C} = W^{UK}\mathbf{c}_t^{KV},\qquad
\mathbf{v}_t^{C} = W^{UV}\mathbf{c}_t^{KV}
$$

where $\mathbf{h}_t$ is the attention input for the $t$-th token, $\mathbf{c}_t^{KV}\in\mathbb{R}^{d_c}$ is the compressed latent vector for the keys and values, $W^{DKV}\in\mathbb{R}^{d_c\times d}$ is the down-projection matrix, and $W^{UK}, W^{UV}\in\mathbb{R}^{d_h n_h\times d_c}$ are the up-projection matrices for the keys and values. During inference, MLA only needs to cache $\mathbf{c}_t^{KV}$. In addition, to reduce activation memory during training, MLA applies a low-rank compression to the queries as well (this does not affect the KV cache):

$$
\mathbf{c}_t^{Q} = W^{DQ}\mathbf{h}_t,\qquad
\mathbf{q}_t^{C} = W^{UQ}\mathbf{c}_t^{Q}
$$

where $\mathbf{c}_t^{Q}\in\mathbb{R}^{d_c'}$ is the latent vector for the queries, and $W^{DQ}$ and $W^{UQ}$ are the corresponding down- and up-projection matrices.
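The compression can be sketched in a few lines of NumPy. The dimensions below are purely illustrative (not DeepSeek-V3's actual sizes), and random matrices stand in for the learned projections:

```python
import numpy as np

rng = np.random.default_rng(0)

d, n_h, d_h, d_c = 1024, 16, 64, 128   # model dim, heads, head dim, KV latent dim
h_t = rng.standard_normal(d)           # attention input for one token

# Down-project to the shared KV latent vector -- this is all that gets cached
W_DKV = rng.standard_normal((d_c, d))
c_kv = W_DKV @ h_t                     # shape (d_c,)

# Up-project the latent back to full per-head keys and values when needed
W_UK = rng.standard_normal((n_h * d_h, d_c))
W_UV = rng.standard_normal((n_h * d_h, d_c))
k = (W_UK @ c_kv).reshape(n_h, d_h)
v = (W_UV @ c_kv).reshape(n_h, d_h)

print(c_kv.shape, k.shape, v.shape)    # (128,) (16, 64) (16, 64)
print(2 * n_h * d_h / d_c)             # cache shrink factor for these toy sizes: 16.0
```

Caching the 128-dim latent instead of the 2 × 16 × 64 key/value values is where the memory saving comes from.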
2. Decoupled RoPE positional encoding
MLA redesigns Rotary Positional Embedding (RoPE) with a decoupled scheme. Standard RoPE applies a position-dependent rotation to Q and K directly, but after the low-rank KV factorization that rotation no longer commutes with the projection matrices, so the up-projections could no longer be absorbed and computation would become complicated. MLA instead carries positional information in extra dedicated dimensions, decoupling the RoPE computation from the low-rank compression: position awareness is preserved without extra overhead. Concretely, the decoupled scheme uses additional multi-head queries $\mathbf{q}_{t,i}^{R}\in\mathbb{R}^{d_h^R}$ and a shared key $\mathbf{k}_t^{R}\in\mathbb{R}^{d_h^R}$ to carry RoPE:

$$
\mathbf{q}_t^{R} = \mathrm{RoPE}\!\left(W^{QR}\mathbf{c}_t^{Q}\right),\qquad
\mathbf{k}_t^{R} = \mathrm{RoPE}\!\left(W^{KR}\mathbf{h}_t\right)
$$

$$
\mathbf{q}_{t,i} = \left[\mathbf{q}_{t,i}^{C};\,\mathbf{q}_{t,i}^{R}\right],\qquad
\mathbf{k}_{t,i} = \left[\mathbf{k}_{t,i}^{C};\,\mathbf{k}_t^{R}\right]
$$

where $W^{QR}\in\mathbb{R}^{d_h^R n_h\times d_c'}$ and $W^{KR}\in\mathbb{R}^{d_h^R\times d}$ are the matrices that produce the decoupled queries and key, $[\cdot;\cdot]$ denotes concatenation, $d_h$ is the per-head dimension, and $d_h^R$ is the per-head dimension of the decoupled queries and key.
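A minimal sketch of the decoupling, with illustrative sizes and a toy `rope` rotation (not the library's implementation): the content parts carry no position at all, RoPE lives only in the extra decoupled dimensions, and the rotated key is shared across heads.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate pairs of dimensions of x by a position-dependent angle (toy RoPE)."""
    half = x.shape[-1] // 2
    angle = pos * base ** (-np.arange(half) / half)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * np.cos(angle) - x2 * np.sin(angle),
                           x1 * np.sin(angle) + x2 * np.cos(angle)], axis=-1)

rng = np.random.default_rng(0)
n_h, d_h, d_hR, d_c = 4, 8, 4, 16   # heads, head dim, decoupled RoPE dim, latent dim
pos = 7                             # token position
h_t = rng.standard_normal(32)       # attention input (toy model dim 32)
c_q = rng.standard_normal(d_c)      # query latent, from the down-projection

# Content parts are position-free; the decoupled parts get the rotary rotation
q_nope = (rng.standard_normal((n_h * d_h, d_c)) @ c_q).reshape(n_h, d_h)
q_rope = rope((rng.standard_normal((n_h * d_hR, d_c)) @ c_q).reshape(n_h, d_hR), pos)
k_rope = rope(rng.standard_normal((d_hR, 32)) @ h_t, pos)  # one shared key per token

# Final per-head query: content part concatenated with the decoupled RoPE part
q = np.concatenate([q_nope, q_rope], axis=-1)   # shape (n_h, d_h + d_hR)
print(q.shape, k_rope.shape)                    # (4, 12) (4,)
```

Because the rotation touches only `q_rope`/`k_rope`, the cached latent stays position-free and the key/value up-projections can still be folded into the query at inference time.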
3. Comparison
At inference time, MLA only needs to cache the compressed latent vector $\mathbf{c}_t^{KV}$ and the position-carrying decoupled key $\mathbf{k}_t^{R}$, rather than the full KV matrices. Compared with traditional multi-head attention (MHA), the per-token KV cache can shrink by roughly 56× (e.g., in DeepSeek-V3, from $2\times128\times128=32768$ cached values for the full K and V down to $512+64=576$).
| | MHA | MLA |
|---|---|---|
| KV cache | stores the full high-dimensional K and V matrices | stores only the low-rank latent vector (plus one shared decoupled key) |
| Positional encoding | applies RoPE directly to Q and K | extra dedicated dimensions; RoPE decoupled from compression |
| Memory footprint | large for long sequences | tens of times smaller than MHA |
| Compute complexity | full-dimension attention | reduced via low-rank down-projection |
| Typical use | general | long-context scenarios |
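The memory row can be sanity-checked with the published DeepSeek-V3 shapes (128 heads of dimension 128, a 512-dim KV latent, 64 RoPE dims); treat the exact figures as illustrative:

```python
# Per-token KV-cache size, counted in cached values, for MHA vs. MLA
n_heads, head_dim = 128, 128        # DeepSeek-V3 attention heads
kv_lora_rank, rope_dim = 512, 64    # KV latent rank and decoupled RoPE dim

mha_cache = 2 * n_heads * head_dim  # full K and V per token
mla_cache = kv_lora_rank + rope_dim # latent + shared decoupled key per token

print(mha_cache, mla_cache, round(mha_cache / mla_cache, 1))  # 32768 576 56.9
```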
III. Python Implementation
Since DeepSeek-V3 is open source, we can pull the development team's MLA implementation straight from the GitHub repository and walk through it.
```python
class MLA(nn.Module):
    """
    Multi-head Latent Attention (MLA) Layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        # If q_lora_rank is 0, project the input straight to the query space with one ColumnParallelLinear
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # Otherwise, factor the query projection into a down-projection and an up-projection
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale
        # Cache intermediate keys (K) and values (V) so later decoding steps can reuse them
        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-head Latent Attention (MLA) layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x
```