GiantPandaCV | FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍

GiantPandaCV | FasterTransformer Decoding 源码分析(一)-整体框架介绍-优快云博客

GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍-优快云博客

GiantPandaCV | FasterTransformer Decoding 源码分析(三)-LayerNorm介绍-优快云博客

作者丨进击的Killua

来源丨https://zhuanlan.zhihu.com/p/669648527

编辑丨GiantPandaCV

本文是FasterTransformer Decoding源码分析的第四篇,也是该系列文章中最核心的一篇。笔者试图去分析selfAttention部分的代码实现和优化,内容较多也比较复杂,笔者会尽最大努力把原理阐述清楚。

一、模块介绍

如下图所示,SelfAttention模块位于DecoderLayer的第二个模块,输入为decoder inputs经过LayerNorm正则化后的结果,经过该模块处理后进行残差连接再输入LayerNorm中。SelfAttention可以简单理解为decoder中对 inputs进行编码生成feature的模块,在后面的流程中会和encoder产生的feature进行crossAttention生成最终的结果。

SelfAttention在decoder中的位置

那么SelfAttention模块本质上就是要实现如下几个公式,这里简单介绍下这几个公式,分别是:

  1. 线性化生成Q、K、V。

  2. 矩阵乘(Q*KT)得到attention Score 。

  3. 对attention Score进行softmax化得到logits。

  4. 使用logits和V进行点乘,再线性化获得最终输出。

SelfAttention 公式

对这几个公式不清楚的可以去看李宏毅老师的讲解视频,每个步骤做了非常详细的介绍。

二、设计&优化

我们先来看下FasterTransformer中针对SelfAttention这个模块设计了哪些优化策略,再来看代码是如何实现的。

1. KV Cache

我们知道在Decoder中解码是逐步进行的,先来看下原始的解码步骤,如下图和文字描述。

  1. step=1,输入= [s], 输出 = 我

  2. step=2,输入= [s] 我, 输出 = [s] 我 有

  3. step=3,输入= [s] 我 有, 输出 = [s] 我 有 猫

  4. step=4,输入= [s] 我 有 猫, 输出 = [s] 我 有 猫 [e]

Decoder 逐步解码过程

因此在逐步解码过程中,针对SelfAttention中Q,K,V矩阵的生成过程如下图所示。(这里仅展示了deocder中首层的SelfAttention,每层的计算逻辑类似)

原始的QKV生成过程

通过观察这个计算过程我们可以发现,每个step的结果中绝大部分的数值都被上个step计算得到过,如下图红框所示。

原始的QKV生成过程,大量重复计算

因此我们可以将每个step中计算过的结果缓存起来,在后续步骤中可以跳过这些内容的计算,只计算增量部分的内容。优化后的计算流程如下图所示,每个step只需计算本次新输入词的Q,K,V,大部分key和value结果均来自前面step计算结果的缓存,这就是KV cache的设计原理,也是经典的空间换时间的优化方法。该例子中仅包含3个step,实际解码过程可能包含上千个step,因此可以节省大量的矩阵计算量,代价就是需要更多的全局内存空间。

优化后的QKV生成过程

2. Cache Layout

根据第一点优化设计,有了Cache后每个step都需要Load Key Cache,FasterTransformer设计了高效的存储layout来支持多轮数据的读写,由于场景上是多读少写(每个key写入一次,需要在多个step中读取),所以设计的初衷是通过牺牲一些写入的效率来最大化读取效率。我们看下Key Cache的shape是:

[num_layer, batch, head_num, size_per_head // x, max_seq_len, x]

这里x是根据数据类型来确定的,比如处理的数据类型是FP32类型(4 bytes),则x=4,即保证最后一个维度的size是16 bytes。为什么要保证16 bytes?因为在很多情况下GPU 的全局内存对齐要求是 128 位(16 bytes),这样首先访问是对齐的,其次同个warp的多个相邻线程可以进行内存联合访问,这样就可以提高存储的访问速度。

num_layer,batch和head_num这前三个维度参数含义比较明确,在具体的核函数内部执行时其为固定值,现对后三个维度的使用和优化进行说明。假设 x=4, max_seq_len=6, size_per_head=8,即后三维是[2, 6, 4],如下图所示。对于一个warp中的线程:

写场景

其在生成了当前词的key后需要将其写入cache中,由thread1负责前16B写入,thead2负责后16B写入,由于中间隔了96B,所以无法做写入合并。

读场景

在需要读入cache中的历史key时,thread1首次循环负责读入第一个key的前16B,thread2首次循环负责读入第二个key的前16B,这两个读请求地址是连续的可以合并请求;同理第二轮循环thread1和thread2也可以合并请求,由此最大化了读取的效率,降低整体耗时。

key cache中 read/write过程

3. Block和Thread设计

我们再来从block和thread的视图来看下是如何实现selfAttention中的公式的。

block视图

每个block负责的运算任务如下图所示,其只负责一个词(即当前需要解码的新词)在一个head中一个step的kqv运算,输出output。

block计算视图

thread视图

具体到每个thread,它会负责该词query化后和 当前key + cache中的某些历史key进行kqv运算,最后在block维度上进行全局归约。

thread计算视图

三、源码分析

1. 方法入口

SelfAttention的调用入口如下,代码,解释下这里的输入和输出,具体逻辑在后面。

输入Tensor

  1. input_query:normalize之后的decoder_input,大小是[batch_size,hidden_units_]

  2. finished: 解码是否结束的标记,大小是[batch_size]

  3. sequence_lengths: 每个句子的长度,大小是[batch_size]

  4. step: 当前解码的步数

  5. cache_indirection(option):记录了解码到当前句子中每个词在前序步骤中的beam_index

输出Tensor

  1. hidden_features: SelfAttention的输出feature,大小是[batch_size,hidden_units_],和input_query大小一致。

  2. key_cache: SelfAttention中存储key的cache,用于后续step的计算。

  3. value_cache: SelfAttention中存储Value的cache,用于后续step的计算。

   // input tensors:    //      decoder_input [batch_size, hidden_dimension],    //      encoder_output [batch_size, mem_max_seq_len, memory_hidden_dimension],    //      encoder_sequence_length [batch_size],    //      finished [batch_size],    //      step [1] on cpu    //      sequence_lengths [batch_size]    //      cache_indirection [local_batch_size / beam_width, beam_width, max_seq_len]    // output tensors:    //      decoder_output [batch_size, hidden_dimension],    //      key_cache [num_layer, batch, head_num, size_per_head // x, max_seq_len, x]    //      value_cache [num_layer, batch, head_num, max_seq_len, size_per_head]    //      key_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension],    //      value_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension] 
       TensorMap self_attention_input_tensors{
            {"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, decoder_normed_input_}},
            {"finished", input_tensors->at(3)},
            {"sequence_lengths", input_tensors->at(5)},
            {"step", input_tensors->at(4)}};        
        self_attention_input_tensors.insertIfValid("cache_indirection", input_tensors->at(6));


        TensorMap self_attention_output_tensors{
            {"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, self_attn_output_}},
            {"key_cache",
             Tensor{MEMORY_GPU,
                    data_type,
                    std::vector<size_t>(output_tensors->at(1).shape.begin() + 1, output_tensors->at(1).shape.end()),
                    output_tensors->at(1).getPtrWithOffset(self_key_cache_offset)}},
            {"value_cache",
             Tensor{MEMORY_GPU,
                    data_type,
                    std::vector<size_t>(output_tensors->at(2).shape.begin() + 1, output_tensors->at(2).shape.end()),
                    output_tensors->at(2).getPtrWithOffset<T>(self_value_cache_offset)}}};


        self_attention_layer_->forward(&self_attention_output_tensors,
                                       &self_attention_input_tensors,
                                       &decoder_layer_weight->at(l).self_attention_weights);

2. 主体框架

主体框架代码由三部分构成,分别是该step的QKV生成、output生成和Linear输出,详见代码。其中第一部分和第三部分都使用了cublas的封装矩阵乘方法gemm,这里就不多介绍了,主要功能逻辑在第二部分output生成。

第一部分:QKV生成

公式里需要做三次乘法,这里直接用了一次矩阵乘就把QKV的结果都生成了,原理是将权重矩阵concat起来再做乘法ÿ

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值