一、引言
论文: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
作者: Stanford University
代码: FlashAttention
特点: 该方法提出将Q、K、V拆分为若干小块,使执行注意力时不需要频繁进行读写操作,而是每个小块只进行一次读写,从而提升注意力的执行速度。
⚠️ 在学习该方法前,建议补充Attention的相关知识。
二、详情
GPU中SRAM和HBM的计算和存储能力如下图:
可见,SRAM计算能力强(17TB/s),HBM的存储容量大(40GB)。因此,GPU的运算通常在SRAM上进行,如果运算结果的内存占用太大,系统会把运算结果先写入HBM,然后从HBM读出来再在SRAM上进行下一步的运算。
于是,我们就得到原始Attention的执行过程:
其中,Q、K、V分别是Query、Key、Value矩阵,S是相似度矩阵,P是权重矩阵,O是输出矩阵。
这里没写除以 d k \sqrt{d_k} dk的操作,不过无伤大雅,因为它对运算的影响并不大。
可见,计算S、P、O时都要进行读取,计算完成后也都要进行写入。然而,运算速度领先于读写速度导致SRAM运算完了要等数据过来才能进行下一步运算,这就拖慢了整体的速度。
2.1 拆分
FlashAttention提出将Q、K、V拆分成若干小块,这样每个小块的S、P矩阵不至于太大到需要写入HBM中,这样就能只在最开始读取Q、K、V、O(之前的运算结果),在SRAM中完成所有运算后,再将新的O写入HBM。
如果没有SoftMax操作,该过程很容易实现,如下图:
分别循环Q和K、V的小块,循环结果求和就是我们所有期望的O。但是,SoftMax阻碍了它的实现,回顾原始SoftMax公式:
s o f t m a x ( s ) j = e s j ∑ k = 1 N e s k softmax(\boldsymbol{s})_j=\frac{e^{s_j}}{\sum_{k=1}^{N}e^{s_k}} softmax(s)j=∑k=1Neskesj
可见,它要把相似度矩阵S的每一行转为一个概率分布。但是分块策略无法一次性获得完整的S中的行,于是FlashAttention在SoftMax中引入了 m ( s ) m(\boldsymbol{s}) m(s),新的SoftMax公式如下:
s o f t m a x ( s ) i = e s i − m ( s ) ∑ j = 1 N e s j − m ( s ) = f i l ( s ) softmax(\boldsymbol{s})_i=\frac{e^{s_i-m(\boldsymbol{s})}}{\sum_{j=1}^{N}e^{s_j-m(\boldsymbol{s})}}=\frac{f_i}{l(\boldsymbol{s})} softmax(s)i=∑j=1Nesj−m(s)esi−m(s)=l(s)fi
其中,最大值 m ( s ) = max i s i m(\boldsymbol{s})=\max_i s_i m(s)=maxisi,指数和 l ( s ) = ∑ i f i l(\boldsymbol{s})=\sum_i f_i l(s)=∑ifi。事实上,该操作不会影响SoftMax的结果,如下:
s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) = [ e 1 e 1 + e 2 + e 3 + e 10 , e 2 e 1 + e 2 + e 3 + e 10 , e 3 e 1 + e 2 + e 3 + e 10 , e 10 e 1 + e 2 + e 3 + e 10 ] = [ e 1 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 2 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 3 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 10 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 ] softmax([1,2,3,10])=[\frac{e^{1}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{2}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{3}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{10}}{e^{1}+e^{2}+e^{3}+e^{10}}]\\=[\frac{e^{1-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{2-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{3-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{10-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}] softmax([1,2,3,10])=[e1+e2+e3+e10e1,e1+e2+e3+e10e2,e1+e2+e3+e10e3,e1+e2+e3+e10e10]=[e1−10+e2−10+e3−10+e10−10e1−10,e1−10+e2−10+e3−10+e10−10e2−10,e1−10+e2−10+e

最低0.47元/天 解锁文章
344

被折叠的 条评论
为什么被折叠?



