- FlashAttention最基础的方案来自使用高速的share memory来加速Softmax操作,实现Softmax的tiling方案。(Q,K,V之间的乘法可由gemm实现。)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-F2WMd8tb-1691511319949)(https://github.com/Dao-AILab/flash-attention/blob/main/assets/flashattn_banner.jpg#pic_center)]](https://i-blog.csdnimg.cn/blog_migrate/10cb900f65d5a174e41243c08c1b7f2a.png)
左侧为GPU各部分的访问速度比较
- FlashAttention使用平铺来防止大型实体化𝑁 ×𝑁 注意力矩阵(虚线框)在(相对)慢的GPU HBM上。
中间为实现过程
- softmax的计算公式

注:我也比较好奇,softmax公式怎么好像变得复杂了?我在参考文献60中找到了答案:
不幸的是,在所表示的数字范围有限的实际硬件上,算法1的第3行(求分母的时候)可能由于指数而上溢或下溢。得到这这种安全形式的改写。 - 作者提出的分解方法

右侧为融合核函数和pytorch实现的速度比较
CG
-
Jax上继承了Numpy计算加速,XLA加速,JIT编译,自动微分等,以下代码不用自己实现cuda函数Implementation of Flash Attention in Jax
-
cuda实现 https://github.com/lucidrains/flash-cosine-sim-attention/tree/main
-
https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization
-
https://github.com/kyegomez/FlashAttention20Triton
-
https://github.com/Lightning-AI/lit-llama
-
Add Flash-Attention to Huggingface Models https://github.com/conceptofmind/flash-gpt
文章介绍了FlashAttention技术,利用高速sharememory加速Softmax操作,避免在HBM上的大型矩阵运算。通过平铺和分解算法处理注意力矩阵,实现GPT-2等模型的7.6倍加速。同时提及了Jax和CUDA的实现以及相关开源项目.

4729

被折叠的 条评论
为什么被折叠?



