通俗易懂的KVcache图解

在分享之前先提出三个问题:

1. 为什么KVCache不保存Q

2. KVCache如何减少计算量

3. 为什么模型回答的长度不会影响回答速度?

本文将带着这3个问题来详解KVcache

KVcache是什么

kv cache是指一种用于提升大模型推理性能的技术,通过缓存注意力机制中的键值(Key-Value)对来减少冗余计算,从而提高模型推理的速度。

不懂Self Attention的可以先去看这篇文章:

原因

首先要知道大模型进行推理任务时,是一个token一个token进行输出的。

例:给GPT一个任务 “对这个句子进行扩充:我爱“

GPT的输出为:

我爱

我爱中

我爱中国

我爱中国美

我爱中国美食

我爱中国美食,

我爱中国美食,因

我爱中国美食,因为

我爱中国美食,因为它

我爱中国美食,因为它好

我爱中国美食,因为它好吃

我爱中国美食,因为它好吃。

通过这个例子可以看出它生成句子是按token输出的(为了方便理解,假设一个字为一个token)。输出的token会与输入的tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符后结束。自回归任务中,token只能和之前的文本做attention计算。

KVcache图解原理

将这个prompt通过embedding生成QKV三个向量。

“我”只能对自己做attention。得到Z_1

“爱”的Q向量对“我”和“爱”的K向量进行计算后再对V进行加权求和算得新向量后输出Z_{\text{2}}

输入到模型后得到新的token“中国”

重复上述过程

可以发现 在此过程中,新token只与之前token的KV有关系,和之前的Q没关系,因此可以将之前的KV进行保存,就不用再次计算。这就是KVcache。

问题回答

问题1:为什么不保存Q

因为每次运算只有当前token的Q向量,之前token的Q根本不需要计算,所以缓存Q没意义。

问题2:KVCache如何减少计算量

减少的就是不用重复计算之前token的KV向量,但是每个新词的Attention还得计算。

问题3:每次推理过程的输入tokens都变长了,为什么推理FLOPs不随之增大而是保持恒定呢?

因为使用了KVcache导致第i+1 轮输入数据只比第i轮输入数据新增了一个token,其他全部相同!因此第i+1轮推理时必然包含了第 i 轮的部分计算。

代码实现

这是自己实现的一个简单的多头注意力机制+KVcache

具体如何实现多头注意力机制可以看这篇文章:

import torch
import torch.nn as nn
import math

class MyMultiheadAttentionKV(nn.Module):
    def __init__(self, hidden_dim: int = 1024, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.heads_num = heads
        self.dropout = nn.Dropout(dropout)
        self.head_dim = hidden_dim // self.heads_num
        self.Wq = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wk = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wv = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.outputlayer = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None, key_cache=None, value_cache=None):
        # x = (batch_size, seq_len, hidden_dim)
        query = self.Wq(x)
        key = self.Wk(x)
        value = self.Wv(x)
        bs, seq_len, _ = x.size()

        # Reshape to (batch_size, heads_num, seq_len, head_dim)
        query = query.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        
        key = key.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        value = value.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        # Cache key and value if provided
        if key_cache is not None and value_cache is not None:
            key = torch.cat([key_cache, key], dim=2)  # Append along sequence dimension
            value = torch.cat([value_cache, value], dim=2)
        

        # Update caches
        key_cache = key
        value_cache = value

        # Calculate attention scores
        score = query @ key.transpose(-1, -2) / math.sqrt(self.head_dim)  # (batch_size, heads_num, seq_len, seq_len)
        if mask is not None:
            # Mask size should match the updated sequence length after cache concatenation
            mask = mask[:, :, :key.size(3), :key.size(3)]  # Crop the mask to the new size

        
        score = torch.softmax(score, dim=-1)
        score = self.dropout(score)
        output = score @ value  # (batch_size, heads_num, seq_len, head_dim)

        # Reshape back to (batch_size, seq_len, hidden_dim)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        output = self.outputlayer(output)

        return output, key_cache, value_cache  # Return output and updated caches

 测试代码

def test_kvcache():
    
    torch.manual_seed(42)

   
    batch_size, seq_len, hidden_dim, heads_num = 3000, 100, 128, 8
    x = torch.rand(batch_size, seq_len, hidden_dim)  # Random input data
    attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))  # Attention mask
    

    net = MyMultiheadAttentionKV(hidden_dim, heads_num)
    

    output, key_cache, value_cache = net(x, attention_mask)
    

    new_x = torch.rand(batch_size, seq_len, hidden_dim)  
    output, key_cache, value_cache = net(new_x, attention_mask, key_cache, value_cache)


    third_x = torch.rand(batch_size, seq_len, hidden_dim) 
    output, key_cache, value_cache = net(third_x, attention_mask, key_cache, value_cache)
    
    print(f"Output shape: {output.shape}")
    print(f"Key cache shape: {key_cache.shape}")
    print(f"Value cache shape: {value_cache.shape}")
    


# Run the test
if __name__ == "__main__":
    test_kvcache()

使用KVcache后:

其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

  1. 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。
  2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。

参考:https://zhuanlan.zhihu.com/p/630832593

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值