1.1 GPU 硬件特点
由于 FlashAttention 计算 self-attention 的主要关键是有效的硬件使用,所以了解GPU内存和各种操作的性能特征是很有必要的。
以 A100 (40GB HBM) 为例,下面显示其内存层次结构的粗略图。SRAM内存分布在108个流式多处理器(SMs)上,每个处理器192KB。片上SRAM比HBM快得多,但比HBM小得多,在计算方面,使用Tensor Core的BFLOAT16 的理论峰值吞吐量为 312 TFLOPS。GPU 的典型操作方式是使用大量的线程来执行一个操作,这个操作被称为内核。输入从HBM加载到寄存器和SRAM,并在计算后写回HBM。
算法对于内存带宽的需求通常使用 计算强度 (arithmetic intensity) 来表示,单位是 OPs/byte。意思是在算法中平均每读入单位数据,能支持多少次运算操作。它有助于理解操作的瓶颈,即计算约束(Compute-bound)或带宽约束(Bandwidth-bound, or Memory-bound)。
- 算力 π :也称为计算平台的性能上限,指的是一个计算平台倾尽全力每秒钟所能完成的浮点运算数。单位是
FLOPS
orFLOP/s
。 - 带宽 β :也即计算平台的带宽上限,指的是一个计算平台倾尽全力每秒所能完成的内存交换量。单位是
Byte/s
。 - 计算强度上限 Imax=πβ :两个指标相除即可得到计算平台的计算强度上限。它描述的是在这个计算平台上,单位内存交换最多用来进行多少次计算。单位是
FLOPs/Byte
。 - 模型的理论性能 P:我们最关心的指标,即模型_在计算平台上_所能达到的每秒浮点运算次数(理论值)。单位是
FLOPS
orFLOP/s
。
如下图所示,Roof-line 描述了模型在一个计算平台的限制下,到底能达到多快的浮点计算速度,即算力决定“屋顶”的高度(绿色线段),带宽决定“房檐”的斜率(红色线段)。
Roof-line 划分出的两个瓶颈区域,即
- 计算约束——此时HBM访问所花费的时间相对较低,不管模型的计算强度 I 有多大,它的理论性能 P 最大只能等于计算平台的算力π。例如,具有较大内维数的矩阵乘法和具有大量通道的卷积。
- 带宽约束——当模型的计算强度 I 小于计算平台的计算强度上限 Imax 时,由于此时模型位于“房檐”区间,因此模型理论性能 P 的大小完全由计算平台的带宽上限 β(房檐的斜率)以及模型自身的计算强度 I 所决定。例如,elementwise 操作 (如activation, dropout 等) 和 规约操作 (如sum, softmax, batch normalization, layer normalization等)。
在 self-attention 中,计算速度比内存速度快得多,因此进程(操作)越来越多地受到内存(HBM)访问的瓶颈。因此,FlashAttention论文的目标是尽可能高效地使用SRAM来加快计算速度
1.2 FlashAttention 的核心思想及细节推导
首先我们回顾一下标准 Attention 的操作:
其中 S,P (对于decoder来说还有 mask)的空间复杂度都是 O(N2) ,另外还有几个带宽约束的操作:对 S 的 scale, mask 和 softmax 操作,对 P 的 dropout 操作。下图算法展示了 HBM 与 SRAM 之间的数据传输过程。
1.2.1 FlashAttention 的优化思路
从上面的分析可以看出,O(N2)复杂度的矩阵对HBM及其重复读写是一个主要瓶颈。要解决这个问题,需要做两件主要的事情:
- 在不访问整个输入的情况下计算 softmax
- 不为反向传播存储大的中间 attention 矩阵
为此 FlashAttention 提出了两种方法来分布解决上述问题:tiling 和 recomputation。
- tiling - 注意力计算被重新构造,将输入分割成块,并通过在输入块上进行多次传递来递增地执行softmax操作。
- recomputation - 存储来自前向的 softmax 归一化因子,以便在反向中快速重新计算芯片上的 attention,这比从HBM读取中间矩阵的标准注意力方法更快。
由于重新计算,这确实导致FLOPs增加,但是由于大量减少HBM访问,FlashAttention运行速度更快(在GPT-2上高达7.6倍)。下面将详细推导 FlashAttention 的细节。
该算法背后的主要思想是分割输入,将它们从慢速HBM加载到快速SRAM,然后计算这些块的 attention 输出。在将每个块的输出相加之前,将其按正确的归一化因子进行缩放,从而得到正确的结果。
下面将分别讨论 tiling 和 recomputation 的正确性:
1.2.2 Tiling 与前向计算
上述算法中除了线性操作和 elementwise 外,分块计算注意力的关键部分是 softmax 的分块计算。向量的 softmax 可以计算为
m ( x ) : = max i x i f ( x ) : = [ e x 1 − m ( x ) … e x B − m ( x ) ] ℓ ( x ) : = ∑ i f ( x ) i softmax ( x ) : = f ( x ) ℓ ( x ) m(x):=\max _i x_i\\ f(x):=\left[\begin{array}{lll} e^{x_1-m(x)} & \ldots & e^{x_B-m(x)} \end{array}\right]\\ \ell(x):=\sum_i f(x)_i\\ \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}\\ m(x):=imaxxif(x):=[ex1−m(x)…exB−m(x)]ℓ(x):=i∑f(x)isoftmax(x):=ℓ(x)f(x)
其中 x 可分解为 x=[x(1)x(2)]∈R2B , x(1),x(2)∈RB 那么则有
m ( x ) = m ( [ x ( 1 ) x ( 2 ) ] ) = max ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) m(x)=m\left(\left[x^{(1)} x^{(2)}\right]\right)=\max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right)\\ m(x)=m([x(1)x(2)])=max(m(x(1)),m(x(2)))
那么可以通过如下构造的方式,使得 f(x) 的结果与分块前保持统一:
f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] = [ e m ( x ( 1 ) ) − m ( x ) [ e x 1 ( 1 ) − m ( x ( 1 ) ) … e x B ( 1 ) − m ( x ( 1 ) ) ] e m ( x ( 2 ) ) − m ( x ) [ e x 1 ( 2 ) − m ( x ( 2 ) ) … e x B ( 2 ) − m ( x ( 2 ) ) ] ] = [ [ e x 1 ( 1 ) − m ( x ) … e x B ( 1 ) − m ( x ) ] [ e x 1 ( 2 ) − m ( x ) … e x B ( 2 ) − m ( x ) ] ] = [ e x 1 − m ( x ) … e x B − m ( x ) ] \begin{aligned} f(x)&=\left[\begin{array}{ll} e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) & e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right) \end{array}\right] \\ &=\left[\begin{array}{ll} e^{m\left(x^{(1)}\right)-m(x)} \left[\begin{array}{lll} e^{x_1^{(1)}-m(x^{(1)})} & \ldots & e^{x_B^{(1)}-m(x^{(1)})} \end{array}\right] & e^{m\left(x^{(2)}\right)-m(x)} \left[\begin{array}{lll} e^{x_1^{(2)}-m(x^{(2)})} & \ldots & e^{x_B^{(2)}-m(x^{(2)})} \end{array}\right] \end{array}\right] \\ &= \left[\begin{array}{ll} \left[\begin{array}{lll} e^{x_1^{(1)}-m(x)} & \ldots & e^{x_B^{(1)}-m(x)} \end{array}\right] & \left[\begin{array}{lll} e^{x_1^{(2)}-m(x)} & \ldots & e^{x_B^{(2)}-m(x)} \end{array}\right] \end{array}\right] \\ &=\left[\begin{array}{lll} e^{x_1-m(x)} & \ldots & e^{x_B-m(x)} \end{array}\right] \end{aligned} f(x)=[em(x(1))−m(x)f(x(1))em(x(2))−m(x)f(x(2))]=[em(x(1))−m(x)[ex1(1)−m(x(1))…exB(1)−m(x(1))]em(x(2))−m(x)[ex1(2)−m(x(2))…exB(2)−m(x(2))]]=[[ex1(1)−m(x)