Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的,
论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。
下面我们通过详细解读这篇论文,来说明什么是Flash Attention。
Transformer在处理长序列时速度慢且占用大量内存,因为自注意力的时间和内存复杂度与序列长度的平方成正比。近似注意力方法尝试通过牺牲模型质量来减少计算复杂度来解决这个问题,但往往不能实现实际速度提升。我们认为一个缺失的原则是使注意力算法具有IO感知能力——考虑GPU内存层级之间的读取和写入。我们提出了FlashAttention,这是一种IO感知的精确注意力算法,使用平铺技术来减少GPU高带宽内存(HBM)和GPU片上静态随机存储器(SRAM)之间的内存读写次数。我们分析了FlashAttention的IO复杂度,表明它需要比标准注意力更少的HBM访问,并且在一定范围内的SRAM大小下是最优的。FlashAttention训练Transformer比现有基准更快:与MLPerf 1.1训练速度记录相比,BERT-large(序列长度为512)的端到端墙钟速度提高了15%,GPT-2(序列长度为1K)的速度提高了3倍,长距离竞技场(序列长度为1K-4K)的速度提高了2.4倍。
左图:FlashAttention使用平铺技术,防止在(相对)较慢的GPU HBM上生成大型的𝑁×𝑁注意力矩阵(虚线框)。在外部循环(红色箭头)中,FlashAttention循环遍历K和V矩阵的块,并将它们加载到快速的片上