【notes】注意力和KV Cache

你可以把 Q、K、V 想象成图书馆的检索系统:Key 是书的目录标签,Value 是书的内容,Query 是你的查询关键词。你想「注意」哪些书,取决于你的 Query 和图书馆里的 Key 的相似度。
假设输入x∈RT×dmodelx \in R^{T \times d_{model}}xRT×dmodel

待补充

Q:为什么qkv要分别乘以三个不同的矩阵?
A:qkv使用不同的矩阵可以学习到不同的投影子空间。

Q:为什么要乘以 1dmodel\frac{1}{\sqrt{d_{model}}}dmodel1
A:为了防止注意力分数过大导致梯度消失。注意力机制中,点积运算的结果可能随着维度 dmodeld_{model}dmodel 增大而变得过大,通过除以 dmodel\sqrt{d_{model}}dmodel 可以稳定梯度,确保训练过程的稳定性。

多头注意力实现代码

import torch
import torch.nn as nn
class MHA(nn.Module):
    def __init__(self, d_model=512,num_heads=8):
        super().__init__()
        self.d_model=d_model
        self.num_heads=num_heads
        self.head_dim=d_model//num_heads
        self.scale=self.head_dim**(-0.5)
        self.wq=nn.Linear(d_model,d_model)
        self.wk=nn.Linear(d_model,d_model)
        self.wv=nn.Linear(d_model,d_model)
        self.wo=nn.Linear(d_model,d_model)
    def forward(self,q,k,v,mask=None):
        #q,k,v: [batch_size,seq_len,d_model]
        batch_size=q.size(0)
        q=self.wq(q).view(batch_size,-1,self.num_heads,self.head_dim).transpose(1,2)
        k=self.wk(k).view(batch_size,-1,self.num_heads,self.head_dim).transpose(1,2)
        v=self.wv(v).view(batch_size,-1,self.num_heads,self.head_dim).transpose(1,2)
        #q,k,v: [batch_size,num_heads,seq_len,head_dim]
        attn=torch.matmul(q,k.transpose(-2,-1))*self.scale
        if mask is not None:
            attn=attn.masked_fill(mask==0,float('-inf'))
        attn=attn.softmax(dim=-1)
        out=torch.matmul(attn,v)
        out=out.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)
        out=self.wo(out)
        return out

在训练阶段,模型可以同时看见整个序列,因此可以一次性并行地计算所有位置的 Q、K、V。而在推理阶段,生成是自回归的:每次只能生成下一个 token。对于新生成的 token,需要以它为 Query,和所有已生成的上下文做注意力计算。这时,Key 和 Value 代表的是之前所有位置的表示,需要保留住这些历史信息。使用 KV cache,就是在每步推理时,把之前步骤里计算好的 Key 和 Value 缓存起来,避免重复前向传播,从而高效地完成下一个 token 的生成。
如下图,随便画的,q1=k1=v1=x1,生成x2保存,然后q2=x2,k2=v2=[x1,x2],生成x3保存。如果不保存,就跟训练一样每次都要算全部的,复杂度是O(n2)O(n^2)O(n2),有了kv cache复杂度就降到了O(n)O(n)O(n)。每次新的 Query 只要和之前缓存下来的所有 K、V 做注意力就能生成下一个token,而不用重复计算之前的K、V。
随便画了一下

MQA:让多头共享一组Key和Value
GQA:分组共享Key和Value
在这里插入图片描述
PagedAttention
vLLM 使用 PagedAttention 以提高推理效率。传统的分段式实现会为每个请求分配一段连续的虚拟内存,需要预先估计可能的最大上下文长度。如果实际输入比预估长度短,就会浪费未使用的内存;而不同请求的长度又可能变化较大,导致内存分段难以复用,从而降低整体吞吐量。PagedAttention 的关键在于将注意力缓存的内存布局按页(page)组织,每页通常包含 16 个 token。请求的 KV 缓存在虚拟内存中按页进行分配和寻址,相同的内容可以复用同一个内存页,不同内容再分配新的页。这样做能够灵活管理不规则的上下文长度,减少碎片和浪费,从而最大化内存利用率并提升大批量请求的吞吐能力。

缓存清理策略(Eviction Policy)
在多任务或长序列的场景下,KV Cache可能占用大量内存。引入缓存清理策略,可以根据需要清理部分缓存,释放内存。常见策略:
LRU(最近最少使用):清理最长时间未使用的缓存。
LFU(最少使用频率):清理使用频率最低的缓存。

缓存合并(Merging)
不直接删除KV缓存,而是通过合并相似的缓存来减少存储需求。例如,将相邻时间步的k和v进行合并,降低缓存规模。

进一步减少内存:KV Cache量化、局部注意力限制窗口步长、流式LLM:保留序列中的首几个token(因为发现它们通常对生成有全局影响)和最近的几个token,丢弃中间的不太重要的token。

import torch
import torch.nn as nn

class KVCacheAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(KVCacheAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.kv_cache = None

    def forward(self, query, key, value):
        if self.kv_cache is not None:
            key = torch.cat([self.kv_cache['key'], key], dim=0)
            value = torch.cat([self.kv_cache['value'], value], dim=0)
        self.kv_cache = {'key': key, 'value': value}
        output, _ = self.attention(query, key, value)
        return output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值