论文:FlashAttention: Fast and Memory-Efficient Exact Attention
with IO-Awareness
0. 引言
Flash Attention的主要目的是加速和节省内存,主要贡献包括:
- 计算softmax时候不需要全量input数据,可以分段计算;
- 反向传播的时候,不存储attention matrix ( N 2 N^2 N2 的矩阵),而是只存储softmax归一化的系数。
1. 动机 & 计算步骤
不同硬件模块之间的带宽和存储空间有明显差异,例如下图中左边的三角图,最顶端的是GPU中的SRAM(片上存储),它的容量非常小但是带宽非常大,以A100 GPU为例,它有108个流式多核处理器,每个处理器上的片上SRAM大小只有192KB,因此A100总共的SRAM大小是 192 K B × 108 = 20 M B 192KB\times 108 = 20MB 192KB×108=20MB,但是其吞吐量能高达19TB/s。
而A100 GPU HBM(High Bandwidth Memory高带宽内存也就是我们常说的GPU显存大小)大小在40GB~80GB左右,但是带宽只与1.5TB/s
FlashAttention的主要动机就是希望把SRAM利用起来,但是难点就在于SRAM太小了,一个普通的矩阵乘法都放不下去。Flash Attention 的解决思路就是将计算模块进行分解,拆成一个个小的计算任务。

下图给出了标准的注意力机制的实现流程,可以看到因为HBM的大小更大,我们平时写pytorch代码的时候最常用到的就是HBM,所以对于HBM的读写操作非常频繁,而SRAM利用率反而不高。

外循环(Outer Loop):

内循环 (Inner Loop):

内外循环如何配合:

内外循环优化的核心优势:

2. Softmax Tiling
在介绍具体的计算算法前,我们首先需要了解一下Softmax Tiling。
(1)数值稳定
Softmax包含指数函数,所以为了避免数值溢出问题,可以将每个元素都减去最大值,如下图示,最后计算结果和原来的Softmax是一致的。
m ( x ) : = max i x i m(x):=\max_{i} ~ x_i m(x):=imax xi
f ( x ) : = [ e x 1 − m ( x ) … e x B − m ( x ) ] \ f(x):=\left[\begin{array}{llll}e^{x_{1}-m(x)} & \ldots & e^{x_{B}-m(x)}\end{array}\right] f(x):=[ex1−m(x)…exB−m(x)]
ℓ ( x ) : = ∑ i f ( x ) i softmax ( x ) : = f ( x ) ℓ ( x ) \ell(x):=\sum_{i} f(x)_{i} \ \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} ℓ(x):=i∑f(x)i softmax(x):=ℓ(x)f(x)
(2)分块计算softmax
因为Softmax都是按行计算的,所以我们考虑一行切分成两部分的情况,即原本的一行数据 x ∈ R 2 B = [ x ( 1 ) , x ( 2 ) ] x \in \mathbb{R}^{2 B}=\left[x^{(1)}, x^{(2)}\right] x∈R2B=[x(1),x(2)]:

可以看到计算不同块的
f
(
x
)
f(x)
f(x)值时,乘上的系数是不同的,但是最后化简后的结果都是指数函数减去了整行的最大值。以
x
(
1
)
x^{(1)}
x(1) 为例:

这里几个需要解释的地方::
-
最大值 m ( x ) m(x) m(x) 的合并:
在 softmax 的计算中,为了数值稳定性,我们引入了最大值归一化技巧(即减去最大值)。对分块 softmax 的情况,假设输入被分为两个块 x ( 1 ) x^{(1)} x(1) 和 x ( 2 ) x^{(2)} x(2),则全局最大值 m ( x ) m(x) m(x) 可以写为:

m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) 和 m ( x ( 2 ) ) m(x^{(2)}) m(x(2)) 是两个块中各自的最大值。 m ( x ) m(x) m(x) 是整个输入的全局最大值。这是为了确保在计算 softmax 时减去的最大值是正确的全局值。 -
分块函数 f ( x ) f(x) f(x) 的重写:
定义分块后的 f ( x ) f(x) f(x) 为 softmax 的分子部分。对于分块 softmax,分子可以写为:

为什么可以这样写?
每个块 f ( x ( 1 ) ) f(x^{(1)}) f(x(1)) 和 f ( x ( 2 ) ) f(x^{(2)}) f(x(2)) 是分别计算的部分分子值:

其中 e m ( x ( 1 ) ) − m ( x ) e^{m(x^{(1)})-m(x)} em(x(1))−m(x) 和 e m ( x ( 2 ) ) − m ( x ) e^{m(x^{(2)})-m(x)} em(x(2))−m(x) 是用来将分块的最大值 m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) 和 m ( x ( 2 ) ) m(x^{(2)}) m(x(2)) 对齐到全局最大值 m ( x ) m(x) m(x)
更直观地可以将 f ( x ( 1 ) ) = ∑ i ∈ b l o c k 1 e x i − m ( x ( 1 ) ) f(x^{(1)})=\sum_{i \in block_1}e^{x_i-m(x^{(1)})} f(x(1))=∑i∈block1exi−m(x(1)) 代入 f ( x ) f(x) f(x) 表达式的第一项:
e m ( x ( 1 ) ) − m ( x ) e x i − m ( x ( 1 ) ) = e x i − m ( x ) e^{m(x^{(1)})-m(x)}e^{x_i-m(x^{(1)})}=e^{{x_i}-m(x)} em(x(1))−m(x)exi−m(x(1))=exi−m(x)
通过这种形式化,计算可以直接在块内完成,而不需要处理全局范围的输入。
-
归一化因子 ℓ ( x ) \ell(x) ℓ(x) 的合并
softmax 的归一化因子 ℓ ( x ) \ell(x) ℓ(x) 是分母部分,表示分子部分的总和:

对于分块的情况, ℓ ( x ) \ell(x) ℓ(x) 可以写成:

ℓ ( x ( 1 ) ) \ell(x^{(1)}) ℓ(x(1)) 和 ℓ ( x ( 2 ) ) \ell(x^{(2)}) ℓ(x(2)) 是块内的归一化因子(对应分块的 ∑ e x i \sum e^{x_i} ∑exi)e m ( x ( 1 ) ) − m ( x ) e^{m(x^{(1)})-m(x)} em(x(1))−m(x) 和 e m ( x ( 2 ) ) − m ( x ) e^{m(x^{(2)})-m(x)} em(x(2))−m(x) 是将每块的归一化因子从局部最大值归一化到全局最大值 m ( x ) m(x) m(x)
结合分子 f ( x ) f(x) f(x) 和分母 ℓ ( x ) \ell(x) ℓ(x) 最终的 softmax 可以写为:

这种形式化推导了分块 softmax 的完整表达式,并允许逐块计算后再合并结果,避免了直接操作全局矩阵的高内存需求。
3. 算法流程
FlashAttention旨在避免从 HBM(High Bandwidth Memory)中读取和写入注意力矩阵,这需要做到:
-
目标一:在不访问整个输入的情况下计算softmax函数的缩减;将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减。
-
目标二:在反向传播中不能存储中间注意力矩阵。标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。
FlashAttention算法流程如下图所示:

为方便理解,下图将FlashAttention的计算流程可视化出来了,简单理解就是每一次只计算一个block的值,通过多轮的双for循环完成整个注意力的计算。

1970

被折叠的 条评论
为什么被折叠?



