本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。
原文链接:FasterTransformer Decoding 源码分析(六)-CrossAttention介绍
GiantPandaCV | FasterTransformer Decoding 源码分析(一)-整体框架介绍-优快云博客
GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍-优快云博客
GiantPandaCV | FasterTransformer Decoding 源码分析(三)-LayerNorm介绍-优快云博客
GiantPandaCV | FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍-优快云博客
GiantPandaCV | FasterTransformer Decoding 源码分析(五)-AddBiasResidualLayerNorm介绍-优快云博客
作者丨进击的Killua
来源丨https://zhuanlan.zhihu.com/p/670739629
编辑丨GiantPandaCV
本文是FasterTransformer Decoding源码分析的第六篇,笔者试图去分析CrossAttention部分的代码实现和优化。由于CrossAttention和SelfAttention计算流程上类似,所以在实现上FasterTransformer使用了相同的底层Kernel函数,因此会有大量重复的概念和优化点,重复部分本文就不介绍了,所以在阅读本文前务必先浏览进击的Killua:FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍这篇文章,一些共性的地方会在这篇文章中做统一介绍,本文着重介绍区别点。
一、模块介绍
如下图所示,CrossAttention模块位于DecoderLayer的第4个模块,输入为经过LayerNorm后的SelfAttention结果和encoder的outputs,经过该模块处理后进行残差连接再输入LayerNorm中。
CrossAttention在decoder中的位置
CrossAttention模块本质上还是要实现如下几个公式,主要的区别在于其中 CrossAttention 的K, V矩阵不是使用 上一个 Decoder block的输出或inputs计算的,而是使用Encoder 的编码信息矩阵计算的,这里还是把公式放出来展示下。
crossAttention 公式
二、设计&优化
整体Block和Thread的执行模型还是和SelfAttention的保持一致,这里不再赘述,主要介绍一下有一些区别的KV Cache。
1. KV Cache
由于在CrossAttention中K,V矩阵是来自于已经计算完成的Encoder输出,所以KV Cache的程度会更大,即第一次运算把KV计算出来之后,后续只要读取Cache即可,不需要用本step的输入再进行线性变换得到增量的部分K,V,如下图所示。
三、源码分析
1. 方法入口
CrossAttention的调用入口如下,解释下这里的输入和输出,具体逻辑在后面。
输入Tensor
-
input_query:normalize之后的SelfAttention输出,大小是[batch_size,hidden_units_]
-
encoder_output: encoder模块的输出,大小是[batch_size, mem_max_seq_len, memory_hidden_dimension]
-
encoder_sequence_length:每个句子的长度,大小是[batch_size]
-
finished: 解码是否结束的标记,大小是[batch_size]
-
step: 当前解码的步数
输出Tensor
-
hidden_features:CrossAttention的输出feature,大小是[batch_size,hidden_units_],和input_query大小