【DeepSeek系列】逐行讲解MLA朴素+吸收矩阵代码全新实现

引言

注意力机制作为大语言模型的核心组件,这么多年从最开始的 MHA 到现在最常用的 MQA、GQA,最主要的目的都是为了节省kv cache的大小。

MHA每一层需要存储【序列长度注意力头数每头维度】的大小,而MQA让每个头的k共享,需要存储的维度直接降低为【序列长度1每头维度】,但后面发现这样降的太多就导致性能下降,所以设计出了一种折中方案。GQA自定义多少个头共享一个k,最终维度变为【序列长度组数每头维度】

MHA、MQA、GQA对比图

以下给出了GQA的计算结构图,这里设置的组数为4,MHA和MQA就是将这个组数修改为注意力头数或1。

GQA计算结构图

MLA 借鉴了LoRA的思路,使用一个降维矩阵将隐层维度降低,然后存储为kv cache,在注意力计算时,使用一个升维矩阵将kv cache升维,从而达到节省kv cache的目的,而且由于升降维矩阵的存在,性能并不会降低(实验证明反而会提高)。

到这个时候,网上已经有很多关于MLA理论的讲解了,但 MLA 听着简单,就是注意力降维、解耦旋转矩阵、吸收矩阵,但你真得搞懂它的内部细节了吗。
MLA 内部涉及10多个矩阵,绕来绕去都晕了,每一步具体怎么切分的,怎么转化维度的,如果让你清晰的描述出来,可能也会很难吧。

本文结合网络图和代码,一步一步详细讲解MLA都做了什么,那么多矩阵都是做什么用的,还请耐心观看。

针对每一个token的注意力计算,都是一个重复的过程,那我们就取中间的一步进行模拟MLA计算。注意这里的维度大小我直接按照deepseek的参数写具体值,这样更为清晰。本文中的矩阵及向量命名都遵守deepseek的命名。

MLA 数据流向

MLA朴素版

首先介绍MLA的常规计算
MLA 朴素版数据流向图

  1. 输入:首先注意力计算 forward 函数会输入隐层向量 hidden_state,记作 h t h_t ht,它的维度是[1, 7168],因为推理时是一个token一个token进行处理的。

    还会输入 kv cache,记作 c K V c^{KV} cKV,它的维度是[n-1, 512],n-1是历史序列长度。

    由于旋转位置编码解耦,所以还要输入一个 k R k^R kR,它的维度是[n, 64],这里k的旋转位置编码在各个头是共享的,所以不需要128*64个

  2. 计算q:首先基于 h t h_t ht 计算当前 token 的 q t C q_t^C qtC q t R q_t^R qtR,即拆分成没有rope和带rope的。先将 h t h_t ht 进行降维,得到 c t Q = h t W D Q c_t^Q=h_tW^{DQ} ctQ=htWDQ,它的维度是 [1, 1536]。

    然后与 W U Q W^{UQ} WUQ 相乘,得到 q t C = c t Q W U Q q_t^C=c_t^QW^{UQ} qtC=ctQWUQ,它的维度是 [1, 128*128],代表128个头,每个头128个维度。

    同理, q t R = c t Q W Q R q_t^R=c_t^QW^{QR} qtR=ctQWQR 的维度是 [1, 128*64]。

  3. 计算c:然后将当前 token 转化成 c 作为 kv cache。

    直接将 h t h_t ht 降维,得到 c t K V = h t W D K V c_t^{KV}=h_tW^{DKV} ctKV=htWDKV,它的维度是 [1, 512]。将其与历史 kv cache 拼接,记作 c K V c^{KV} cKV,它的维度是 [n, 512]。同时将本次的 c K V c^{KV} cKV 存储下来,用于下次计算。

  4. 计算kv:处理 kv cache 即 c K V c^{KV} cKV,得到可计算的 k 和 v。

    k C = c K V W U K k^C=c^{KV}W^{UK} kC=cKVWUK,它的维度是 [n, 128*128]。

    v C = c K V W U V v^C=c^{KV}W^{UV} vC=cKVWUV,它的维度是 [n, 128*128]。

    k t R = h t W K R k_t^R=h_tW^{KR} ktR=htWKR,它的维度是 [1, 64],与输入的 k p e c a c h e k_{pe} cache kpecache 拼接到一起,得到 k t R k_t^R ktR,维度是 [n, 64],注意这里每个头之间是共享的,所以不需要128*64个。但是在后续注意力计算的时候需要维度广播,复制出128份。

  5. 计算注意力权重:

    a t t n C = q t C ( k C ) T attn^C = q_t^C(k^C)^T attnC=qtC(kC)T,它的维度是 [n, 128]。

    a t t n R = q t R ( k R ) T attn^R = q_t^R(k^R)^T attnR=qtR(kR)T,它的维度是 [n, 128]。

    a t t n = a t t n C + a t t n R attn = attn^C + attn^R attn=attnC+attnR

    a t t n _ w e i g h t = s o f t m a x ( a t t n d ) attn\_weight = softmax(\frac{attn}{\sqrt{d}}) attn_weight=softmax(d attn)

    这里带rope和不带rope的注意力是分开算的,根据矩阵的性质,分开计算再相加与合并后计算的结果是相同的。

    其等价于:

    a t t n = [ q t C ; q t R ] ( [ k C ; k R ] ) T attn = [q_t^C; q_t^R]([k^C; k^R])^T attn=[qtC;qtR]([kC;kR])T

  6. 与v相乘: a t t n _ o u t p u t = a t t n _ w e i g h t ∗ v C attn\_output = attn\_weight * v^C attn_output=attn_weightvC

  7. 最终输出: o u t p u t = a t t n _ o u t p u t ∗ W O output = attn\_output * W^{O} output=attn_outputWO

总体公式为(当前token转化成c需要单独计算,且忽略rope的部分):

o u t p u t = s o f t m a x ( ( h t W D Q W U Q ) ( c K V W U K ) T d ) ( c K V W U V ) W O = s o f t m a x ( h t W D Q W U Q W U K T c K V T d ) c K V W U V W O output = softmax(\frac{(h_tW^{DQ}W^{UQ})(c^{KV}W^{UK})^T}{\sqrt{d}}) (c^{KV}W^{UV}) W^{O} \\ = softmax(\frac{h_tW^{DQ}W^{UQ}W^{UK^T}c^{KV^T}}{\sqrt{d}})c^{KV}W^{UV}W^{O} output=softmax(d (htWDQWUQ)(cKVWUK)T)(cKVWUV)WO=softmax(d htWDQWUQWUKTcKVT)cKVWUVWO

MLA 吸收矩阵版

接下来介绍MLA吸收矩阵的计算方式
MLA 吸收矩阵数据流向图
上面总体公式中 W U Q W U K T W^{UQ}W^{UK^T} WUQWUKT 是挨着的, W U V W O W^{UV}W^{O} WUVWO 也是挨着的,所以可以提前合并成一个矩阵,记作 W U Q K W^{UQK} WUQK W U V O W^{UVO} WUVO,这样每次推理就不用进行两次矩阵运算了,加快推理速度,这个就叫做吸收矩阵 (absorb matrix)

那吸收之后的总体公式变为:

o u t p u t = s o f t m a x ( h t W D Q W U Q K c K V T d ) c K V W U V O output = softmax(\frac{h_tW^{DQ}W^{UQK}c^{KV^T}}{\sqrt{d}})c^{KV}W^{UVO} output=softmax(d htWDQWUQKcKVT)cKVWUVO

那整体计算流程就变成了:

  1. 输入:与常规相同

  2. 计算q:还是首先基于 h t h_t ht 计算当前 token 的 q t C q_t^C qtC q t R q_t^R qtR。首先还是将 h t h_t ht降维,得到 c t Q = h t W D Q c_t^Q=h_tW^{DQ} ctQ=htWDQ,它的维度是 [1, 1536]。

    q t R q_t^R qtR与常规相同: q t R = c t Q W Q R q_t^R=c_t^QW^{QR} qtR=ctQWQR 的维度是 [1, 128*64]。

    q t C q_t^C qtC直接一步到位乘以吸收矩阵: q t C = c t Q W U Q K = h t W D Q W U Q K q_t^C=c_t^QW^{UQK}=h_tW^{DQ}W^{UQK} qtC=ctQWUQK=htWDQWUQK,它的维度是 [1, 128*512]。

  3. 计算c:与常规相同

  4. 计算kv:这步删除掉处理 k C k^C kC v C v^C vC 的步骤,不需要提前分解 c K V c^{KV} cKV了,但 k R k^R kR 与常规相同。

  5. 计算注意力权重:

    a t t n C = q t C ( c K V ) T attn^C = q_t^C(c^{KV})^T attnC=qtC(cKV)T,它的维度是 [n, 128]。

    a t t n R = q t R ( k R ) T attn^R = q_t^R(k^R)^T attnR=qtR(kR)T,与常规一样,它的维度是 [n, 128]。

    a t t n = a t t n C + a t t n R attn = attn^C + attn^R attn=attnC+attnR

    a t t n _ w e i g h t = s o f t m a x ( a t t n d ) attn\_weight = softmax(\frac{attn}{\sqrt{d}}) attn_weight=softmax(d attn)

  6. 与v相乘+最终输出:两步合为一步, o u t p u t = a t t n _ w e i g h t ∗ c K V W U V O output = attn\_weight * c^{KV}W^{UVO} output=attn_weightcKVWUVO

至此,MLA就介绍完了,现在你还能复述一遍 MLA 的计算流程吗?那些矩阵还能分得清吗?如果都能搞懂,说明你真得掌握了 MLA,可以去看看 flash MLA 了[狗头]。

代码详解

这里给出MLA实现的代码,参考transformers其他模型注意力部分的实现原理,传入kv cache的同时要传出kv cache。以下代码完全按照上述MLA数据流向图设计的,与图对比基本都能看懂。

首先引入ROPE+RMSNorm 代码初始化

# 改编自:https://github.com/flashinfer-ai/flashinfer/blob/738460ff82e2230ebcc8dff50e49e1d6278e011a/tests/test_mla_decode_kernel.py
from typing import Optional, Tuple

import time
import torch
from torch import nn

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

朴素版的MLA代码代码

class DeepseekV2AttentionVanilla(nn.Module):
    def __init__(self):
        super().__init__()
        # 以 deepseekv2 参数为准
        self.hidden_size = 7168
        self.num_heads = 128

        self.q_lora_rank = 1536
        self.qk_rope_head_dim = 64
        self.kv_lora_rank = 512
        self.v_head_dim = 128
        self.qk_nope_head_dim = 128
        self.q_head_dim = 192  # 192 = 128 + 64 = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.rope_theta = 10000
        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)
        self.softmax_scale = self.q_head_dim ** (-0.5)

        # W^DQ ~ [7168, 1536]
        self.W_DQ = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
        # W^UQ ~ [1536, 128*128]
        self.W_UQ = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^QR ~ [1536, 128*64]
        self.W_QR = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_rope_head_dim, bias=False)
        # W^KR ~ [1536, 64]
        self.W_KR = nn.Linear(self.hidden_size, self.qk_rope_head_dim, bias=False)
        # W^DKV ~ [7168, 512]
        self.W_DKV = nn.Linear(self.hidden_size, self.kv_lora_rank, bias=False)
        # W^UK ~ [512, 128*128]
        self.W_UK = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^UV ~ [512, 128*128]
        self.W_UV = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False)
        # W^O ~ [128*128, 7168]
        self.W_O = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False)


    def run_decode(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取维度:[batch_size, query_length=1, hidden_size]
        bsz, q_len, _ = hidden_states.size()
        
        # 当前step输入的x,经过W_DQ,得到 [batch_size, 1, q_lora_rank]
        c_t_Q = self.q_a_layernorm(self.W_DQ(hidden_states))
        # 再经过W_UQ,得到 [batch_size, 1, num_heads=128 * qk_nope_head_dim=128]
        q_t_C = self.W_UQ(c_t_Q)
        # 再经过W_QR,得到 [batch_size, 1, num_heads=128 * qk_rope_head_dim=64]
        q_t_R = self.W_QR(c_t_Q).view(bsz, -1, self.num_heads, self.qk_rope_head_dim)
        # 再经过W_KR,得到 [batch_size, 1, qk_rope_head_dim=64]
        # 将当前step的k_t_R添加到k_pe_cache的最后一个位置,得到新的k_pe_cache
        k_t_R = self.W_KR(hidden_states)
        k_pe_cache = torch.cat([k_pe_cache, k_t_R], dim=1)
        # 将最后一个维度拆开,方便注意力计算
        q_t_C = q_t_C.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)

        c_t_KV = self.W_DKV(hidden_states)
        compressed_kv_normed_cache = torch.cat([compressed_kv_normed_cache, c_t_KV], dim=1)
        k_C = self.W_UK(compressed_kv_normed_cache).view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        v_C = self.W_UV(compressed_kv_normed_cache).view(bsz, -1, self.num_heads, self.v_head_dim).transpose(1, 2)
        
        # 计算位置编码,暂时不用管,与其他的RoPE计算方式类似,最终得到旋转之后的 q_pe, k_pe
        freqs_cis = precompute_freqs_cis(self.qk_rope_head_dim, compressed_kv_normed_cache.shape[1], self.rope_theta, use_scaled=False).to(q_t_R.device)
        q_t_R, k_R = apply_rotary_emb(
            q_t_R.repeat(1, compressed_kv_normed_cache.shape[1], 1, 1),
            k_pe_cache.unsqueeze(2),
            freqs_cis,
        )
        q_t_R = q_t_R[:, -1:, :, :].transpose(1, 2)
        k_R = k_R.transpose(1, 2).repeat(1, self.num_heads, 1, 1)

        attn_R = torch.matmul(q_t_R, k_R.transpose(2, 3))
        attn_C = torch.matmul(q_t_C, k_C.transpose(2, 3))

        attn_weights = (attn_R + attn_C) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_t_C.dtype)

        # 将注意力权重和v相乘,得到注意力输出,维度为[batch_size, num_heads, q_len, v_head_dim=128]
        attn_output = torch.matmul(attn_weights, v_C)

        # 将最后一个维度展开,得到[batch_size, num_heads, q_len, v_head_dim=128]
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        )

        # 将注意力输出和W^O相乘,得到最终的输出,维度为[batch_size, q_len, hidden_size=7168]
        output = self.W_O(attn_output)

        return output, attn_weights, compressed_kv_normed_cache, k_pe_cache

mla_vanilla = DeepseekV2AttentionVanilla()

batch_size = 6
kv_len = 10

hidden_states = torch.randn([batch_size, 1, mla_vanilla.hidden_size])
compressed_kv_normed_cache = torch.randn([batch_size, kv_len, mla_vanilla.kv_lora_rank])
k_pe_cache = torch.randn([batch_size, kv_len, mla_vanilla.qk_rope_head_dim])

start_time = time.time()
for i in range(10):
    output_vanilla, attn_weights, compressed_kv_normed_cache, k_pe_cache = mla_vanilla.run_decode(
        hidden_states, compressed_kv_normed_cache, k_pe_cache
    )
    # print('output_vanilla.shape', output_vanilla.shape)
    # print('attn_weights.shape', attn_weights.shape)
    # print('compressed_kv_normed_cache.shape', compressed_kv_normed_cache.shape)
    # print('k_pe_cache.shape', k_pe_cache.shape)
    # print('-'*70)
end_time = time.time()
print('time', end_time - start_time)

MLA 吸收矩阵版的代码

from torch import nn
class DeepseekV2AttentionMatAbsorbDecode(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = 7168
        self.num_heads = 128

        self.q_lora_rank = 1536
        self.qk_rope_head_dim = 64
        self.kv_lora_rank = 512
        self.v_head_dim = 128
        self.qk_nope_head_dim = 128
        self.q_head_dim = 192  # 192 = 128 + 64 = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.rope_theta = 10000
        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)
        self.softmax_scale = self.q_head_dim ** (-0.5)

        # W^DQ ~ [7168, 1536]
        self.W_DQ = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
        # W^UQ ~ [1536, 128*128]
        self.W_UQ = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^QR ~ [1536, 128*64]
        self.W_QR = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_rope_head_dim, bias=False)
        # W^KR ~ [1536, 64]
        self.W_KR = nn.Linear(self.hidden_size, self.qk_rope_head_dim, bias=False)
        # W^DKV ~ [7168, 512]
        self.W_DKV = nn.Linear(self.hidden_size, self.kv_lora_rank, bias=False)
        # W^UK ~ [512, 128*128]
        self.W_UK = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^UV ~ [512, 128*128]
        self.W_UV = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False)
        # W^O ~ [128*128, 7168]
        self.W_O = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False)

        # 由于nn.Linear初始化的时一个对象,权重矩阵只是对象中的一个类,没法直接两个矩阵相乘
        # 所以需要用.weight来取出来,而且因为.weight的维度与初始化是反的,所以需要用t()来转置
        # W_UQ_absorb ~ [1536, 128, 128]
        W_UQ_absorb = self.W_UQ.weight.t().view(self.q_lora_rank, self.num_heads, self.qk_nope_head_dim)
        # W_UK_absorb ~ [512, 128, 128]
        W_UK_absorb = self.W_UK.weight.t().view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim)
        # W_UV_absorb ~ [512, 128, 128]
        W_UV_absorb = self.W_UV.weight.t().view(self.kv_lora_rank, self.num_heads, self.v_head_dim)
        # W_O_absorb ~ [7168, 128, 128]
        W_O_absorb = self.W_O.weight.view(self.hidden_size, self.num_heads, self.v_head_dim)

        # 吸收矩阵:将W_UQ和W_UK合并,得到新的W_UQK,维度为[1536, 128, 128]
        # q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
        # 这里把n当做batch_size,也就是矩阵相乘不会影响的那个维度,矩阵qd与dl相乘,得到ql,加上刚才的n,所以得到qnl
        # 再将其flatten展平,得到[1536, 65536]
        self.W_UQK = torch.einsum("q n d, l n d -> q n l", W_UQ_absorb, W_UK_absorb).flatten(start_dim=1)
        # 吸收矩阵,将W_UV和W_O合并,得到新的W_UV_O,维度为[128, 512, 7168]
        # l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
        # 这里把n当做batch_size,也就是矩阵相乘不会影响的那个维度,矩阵ld与dh相乘,得到lh,加上刚才的n,并把n放到最前面,所以得到nlh
        # 再将其flatten展平,得到[65536, 7168]
        self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV_absorb, W_O_absorb).flatten(start_dim=0, end_dim=1)

    def run_decode(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取维度:[batch_size, query_length=1, hidden_size]
        bsz, q_len, _ = hidden_states.size()

        # 当前step输入的x,经过W_DQ,得到 [batch_size, 1, q_lora_rank]
        c_t_Q = self.q_a_layernorm(self.W_DQ(hidden_states))
        # 再经过W_UQ,得到 [batch_size, 1, num_heads=128 * qk_nope_head_dim=128]
        q_t_C = torch.matmul(c_t_Q, self.W_UQK)
        # 再经过W_QR,得到 [batch_size, 1, num_heads=128 * qk_rope_head_dim=64]
        q_t_R = self.W_QR(c_t_Q).view(bsz, -1, self.num_heads, self.qk_rope_head_dim)
        # 再经过W_KR,得到 [batch_size, 1, qk_rope_head_dim=64]
        # 将当前step的k_t_R添加到k_pe_cache的最后一个位置,得到新的k_pe_cache
        k_t_R = self.W_KR(hidden_states)
        k_pe_cache = torch.cat([k_pe_cache, k_t_R], dim=1)
        # 将最后一个维度拆开,方便注意力计算
        q_t_C = q_t_C.view(bsz, q_len, self.num_heads, self.kv_lora_rank).transpose(1, 2)

        c_t_KV = self.W_DKV(hidden_states)
        compressed_kv_normed_cache = torch.cat([compressed_kv_normed_cache, c_t_KV], dim=1)
        
        # 计算位置编码,暂时不用管,与其他的RoPE计算方式类似,最终得到旋转之后的 q_pe, k_pe
        freqs_cis = precompute_freqs_cis(self.qk_rope_head_dim, compressed_kv_normed_cache.shape[1], self.rope_theta, use_scaled=False).to(q_t_R.device)
        q_t_R, k_R = apply_rotary_emb(
            q_t_R.repeat(1, compressed_kv_normed_cache.shape[1], 1, 1),
            k_pe_cache.unsqueeze(2),
            freqs_cis,
        )
        q_t_R = q_t_R[:, -1:, :, :].transpose(1, 2)
        k_R = k_R.transpose(1, 2).repeat(1, self.num_heads, 1, 1)

        attn_R = torch.matmul(q_t_R, k_R.transpose(2, 3))
        attn_C = torch.matmul(q_t_C, compressed_kv_normed_cache.unsqueeze(1).transpose(2, 3))

        attn_weights = (attn_R + attn_C) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_t_C.dtype)

        # attn_weights * c^KV * W^UVO
        attn_output = torch.matmul(
            attn_weights.squeeze(2),  # [bsz, 128, kv_len]
            compressed_kv_normed_cache,  # [bsz, kv_len, 512]
        ).reshape(bsz, self.num_heads * self.kv_lora_rank)
        output = torch.matmul(attn_output, self.W_UV_O,)  # W_UV_O ~ [65536, 7168]

        return output, attn_weights, compressed_kv_normed_cache, k_pe_cache
        

bsz = 6
kv_len = 10

hidden_states = torch.randn([bsz, 1, 7168])
compressed_kv_normed_cache = torch.randn([bsz, kv_len, 512])
k_pe_cache = torch.randn([bsz, kv_len, 64])

mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode()
start_time = time.time()
for i in range(10):
    output_vanilla, attn_weights, compressed_kv_normed_cache, k_pe_cache = mla_mat_absorb.run_decode(
        hidden_states, compressed_kv_normed_cache, k_pe_cache
    )
    # print('output_vanilla.shape', output_vanilla.shape)
    # print('attn_weights.shape', attn_weights.shape)
    # print('compressed_kv_normed_cache.shape', compressed_kv_normed_cache.shape)
    # print('k_pe_cache.shape', k_pe_cache.shape)
    # print('-'*70)
end_time = time.time()
print('time', end_time - start_time)

计算量对比

其实两个矩阵吸收之后的维度是比两个矩阵相加的参数量是多的,但是在计算时由于n的存在,n越大,吸收矩阵的计算量越小,所以加速越多。

实验表明,n<=25时吸收前计算量更低,n>=26吸收后计算量更低,但实际使用prompt的长度都不止26,所以吸收之后肯定会变快。

# 实际吸收后的矩阵更大了,计算量更多了,但由于n的存在,序列越长,总体计算量越小
n=20000
W_UQK = 1536*128*128 + 512*128*128*n + 128*128*n
W_UQK_absorbed = 1536*128*512 + 128*512*n
W_UV_O = 512*128*128*n + 128*128*n + 128*128*7168
W_UV_O_absorbed = 128*512*7168 + 128*512*n
print('W_UQK吸收前:', W_UQK)
print('W_UQK吸收后:', W_UQK_absorbed)
print('W_UV_O吸收前:', W_UV_O)
print('W_UV_O吸收后:', W_UV_O_absorbed)
print('全部吸收前:', W_UQK + W_UV_O)
print('全部吸收后:', W_UQK_absorbed + W_UV_O_absorbed)

# W_UQK吸收前: 168125005824
# W_UQK吸收后: 1411383296
# W_UV_O吸收前: 168217280512
# W_UV_O吸收后: 1780482048
# 全部吸收前: 336342286336
# 全部吸收后: 3191865344

参考资料

  1. https://arxiv.org/pdf/2405.04434
  2. https://kexue.fm/archives/10091
  3. https://mp.weixin.qq.com/s/E7NwwMYw14FRT6OKzuVXFA
  4. https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L682
  5. https://github.com/flashinfer-ai/flashinfer/blob/738460ff82e2230ebcc8dff50e49e1d6278e011a/tests/test_mla_decode_kernel.py
### DeepSeek MoE 架构 DeepSeek-V3 采用了名为 DeepSeekMoE 的架构来改进前馈网络 (FFNs)[^1]。这种架构相较于传统模型如 GShard,在专家设计上更为精细,部分专家被隔离为共享组件。这使得计算资源能够得到更加高效的分配和利用。 #### 技术特点 - **细粒度专家**:相比于其他 MoE 架构,DeepSeekMoE 中的专家数量更多但规模较小。 - **资源共享**:某些特定类型的专家可以在不同层间共享,从而减少冗余并提高效率。 - **性能提升**:通过上述优化措施,整体计算成本降低的同时保持甚至提升了预测精度[^2]。 ```python import torch.nn as nn class FineGrainedExpert(nn.Module): def __init__(self, input_size, output_size): super(FineGrainedExpert, self).__init__() self.fc = nn.Linear(input_size, output_size) def forward(self, x): return self.fc(x) ``` ### 多头潜意识注意力MLA)机制 为了进一步增强模型的表现力,DeepSeek 结合了多头潜意识注意力(Multi-head Latent Attention, MLA)技术。该方法允许各个注意力建立不同的局部视窗来进行特征提取,并且可以集成 RoPE(旋转位置嵌入)以处理序列数据中的相对距离信息[^3]。 #### 关键要素 - **多头结构**:多个独立运作的关注模块共同作用于输入序列的不同片段。 - **灵活窗口大小**:各头部可以根据任务需求调整其关注范围内的元素数目。 - **RoPE 编码支持**:引入旋转位置编码方式帮助捕捉长期依赖关系而不受绝对位置影响。 ```python import math from typing import Optional def apply_rotary_pos_emb(q, k, rope): seq_len = q.shape[-2] cos = rope.cos(seq_len).to(q.device) sin = rope.sin(seq_len).to(q.device) q_embedded = (q * cos) + (rotate_half(q) * sin) k_embedded = (k * cos) + (rotate_half(k) * sin) return q_embedded, k_embedded def rotate_half(x): batch_size, num_heads, seq_length, head_dim = x.size() half_head_dim = head_dim // 2 first_part = x[:, :, :, :half_head_dim] second_part = x[:, :, :, half_head_dim:] rotated_x = torch.cat((-second_part, first_part), dim=-1) return rotated_x ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值