DeepSeek中的Multi-head Latent Attention技术

一、概念

        近来DeepSeek可谓是蜚声中外,而最新的DeepSeek-V3版本模型中,开发团队指出为了实现高效的推理和低成本的训练,DeepSeek-V3采用了Multi-head Latent Attention,即MLA技术。 MLA是一种创新的注意力机制,旨在显著降低推理时的显存占用和计算开销,同时保持模型性能,在论文《DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model》中被提出,本文将详细介绍相关概念。

二、核心原理

        MLA相较于传统多头注意力机制MHA的优化主要有以下几个方面:

1、低秩KV联合压缩机制

        传统MHA需要保存完整的Key和Value缓存,从而很大程度上制约了batch size的扩大和序列长度的增加。MLA 的核心思想则是通过低秩分解对注意力中的键(Key)和值(Value)进行联合压缩,以此来减小缓存占用:

        其中,是注意力层中第 t 个token的注意力输入,是keys和values对应的压缩后的latent向量,是下投影矩阵,则是keys和values的上投影矩阵。在推理的过程中,MLA只需要缓存。此外,为了减少训练时激活函数的memory,MLA对于queries也进行了低秩压缩(当然这并不会对KV缓存造成影响):

        其中,是queries对应的latent向量,分别是下投影矩阵和上投影矩阵。

2、RoPE 位置编码的解耦

        MLA 对 Rotary Positional Embedding (RoPE)进行了解耦设计。传统的 RoPE 需要对 Q 和 K 分别应用位置编码,但低秩分解KV后矩阵运算的交换律问题会导致计算复杂。MLA 通过新增独立维度保存位置信息,将 RoPE 的计算与低秩压缩解耦,既保留了位置感知能力,又避免了额外的计算开销。具体来说,解耦的过程使用额外的多头queries(定义为)和一个共享的键来支撑RoPE,则有:

        其中,分别是生成解耦queries和key的矩阵,表示连接操作,为每个头的维度,则表示解耦queries和key的每个头维度。

3、对比

        MLA 在推理时仅需缓存压缩后的latent向量和与位置相关的解耦键,而非完整的 KV 矩阵。与传统多头注意力(MHA)相比,显存占用可减少 56 倍(例如从 16384 维压缩至 576 维)。

MLA和MHA对比
 MHAMLA
KV缓存完整存储KV高维矩阵仅存储低秩latent向量
位置编码直接应用RoPE新增独立维度,解耦RoPE
显存占用长序列场景显存占用大相较MHA显存占用减小数十倍
计算复杂度高复杂度低复杂度(低秩矩阵降维)
适用场景通用长上下文场景

三、python实现

        由于DeepSeek-V3的代码是开源的,我们直接从GitHub里面把开发团队设计的MLA代码扒下来看看。

class MLA(nn.Module):
    """
    Multi-Headed Attention Layer (MLA).

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        # 如果 q_lora_rank 为 0,直接使用 ColumnParallelLinear 将输入投影到查询空间
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # 如果 q_lora_rank 不为 0,使用两层线性变换
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        # 存储键(K)和值(V)的中间结果,以便在后续计算中复用,减少重复计算
        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Headed Attention Layer (MLA).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

 

 

### 多头潜在注意力机制概述 多头潜在注意力机制是一种用于捕捉用户不同兴趣点的方法,在推荐系统中表现出色。通过引入多个并行的注意流,该方法能够更精细地区分用户的多样化偏好[^1]。 #### 工作原理 在具体实现上,输入特征首先被映射到不同的子空间内形成查询向量(Q),键向量(K)以及值向量(V)。对于每一个单独的兴趣维度而言: - 查询向量代表当前时刻模型关注的重点; - 键向量表示所有可能的关注对象; - 值向量则承载着实际的信息内容。 接着计算Q与K之间的相似度得分,并经过softmax函数转换成概率分布形式作为权重系数应用于V之上得到加权求和后的输出O。这一过程可以表达如下公式所示: \[ O_i = \sum_j w_{ij} V_j, \quad where\;w_{ij}=softmax(\frac{Q_i K_j^\top}{\sqrt{d_k}})\] 其中\( d_k \) 是键向量的维度大小,用来缩放点积结果以稳定梯度传播。 为了进一步增强表征能力,多头设计允许同时学习一组独立但互补的兴趣视角。最终这些来自不同头部的结果会被拼接起来并通过线性变换整合为统一输出供后续层处理。 ```python import torch.nn as nn class MultiHeadLatentAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadLatentAttention, self).__init__() assert embed_size % num_heads == 0 self.embed_size = embed_size self.num_heads = num_heads self.head_dim = embed_size // num_heads self.values = nn.Linear(self.head_dim, self.head_dim, bias=False) self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False) self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False) self.fc_out = nn.Linear(embed_size, embed_size) def forward(x): N, seq_len, _ = x.shape values = x.chunk(self.num_heads, dim=-1)[0].view(N, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) keys = x.chunk(self.num_heads, dim=-1)[1].view(N, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) queries= x.chunk(self.num_heads, dim=-1)[-1].view(N, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, seq_len, self.embed_size) return self.fc_out(out) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值