FlashAttention

FlashAttention is an optimization for Transformer models that uses tiling and recomputation to compute exact attention both fast and memory-efficiently. It avoids loading the full N*N attention matrix: softmax is computed block by block while only partial statistics (running max and exponential sum) are stored, which reduces memory traffic and, especially for long sequences, improves training speed and model quality. In the backward pass the intermediate result P is not saved, further lowering the memory cost.

Sources

paper: https://arxiv.org/abs/2205.14135

an informal talk by the author Tri Dao: https://www.youtube.com/watch?v=FThvfkXWqtE

code repo: https://github.com/HazyResearch/flash-attention (fast and memory-efficient exact attention)

introduction to transformer: Transformer and Bert (EverNoob's CSDN blog)

Key Takeaways

Key features of FlashAttention (in the paper title): fast, memory-efficient, exact

Key algorithmic ideas

  • tiling: avoid loading/unloading N*N matrices at once ==> one key here is how to break softmax down into linearly recoverable partial sums;
  • recomputation: avoid saving intermediate results, compute them again during back-propagation instead.

The upshot: faster model training, better models with longer sequences (longer context for more accurate parsing).

Attention to FlashAttention

Regular Attention:

the heavy memory accesses are the HBM reads and writes of the N*N matrices S = QK^T and P = softmax(S);

a full version of the algorithm, including the optional mask and dropout steps, is sketched below:
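Since the original figure is not reproduced here, the following is a minimal NumPy sketch of standard attention (function and variable names are mine, not the paper's); the comments mark where the N*N matrices would be written to and read from HBM in a real kernel:

```python
import numpy as np

def standard_attention(Q, K, V, mask=None, dropout_p=0.0, rng=None):
    """Standard attention: materializes the full N x N matrices S and P,
    which a real kernel round-trips through HBM."""
    d = Q.shape[-1]
    # 1) S = Q K^T / sqrt(d): an N x N matrix, written to HBM
    S = Q @ K.T / np.sqrt(d)
    # optional mask (e.g. causal): another full N x N read-modify-write
    if mask is not None:
        S = np.where(mask, S, -np.inf)
    # 2) P = row-wise softmax(S): S is read back, P (N x N) written out
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    # optional dropout: yet another full N x N pass
    if dropout_p > 0.0:
        rng = rng or np.random.default_rng()
        P = P * (rng.random(P.shape) >= dropout_p) / (1.0 - dropout_p)
    # 3) O = P V: P is read back from HBM one more time
    return P @ V
```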

the mask and dropout stages are not necessary, but they can be used to improve training results; both of these optional steps introduce additional heavy memory costs, since each is another full N*N pass over the attention matrix.

to reduce memory traffic, we want to:

1. compute softmax piecewise per tiling

2. skip saving the intermediate results (S, P), which are the most memory-demanding objects, each being N*N

From the paper: "We apply two established techniques (tiling, recomputation) to overcome the technical challenge of computing exact attention in sub-quadratic HBM accesses. We describe this in Algorithm 1. The main idea is that we split the inputs Q, K, V into blocks, load them from slow HBM to fast SRAM, then compute the attention output with respect to those blocks."

Tiling

By scaling the output of each block by the right normalization factor before adding them up, we get the correct result at the end.

there is a mathematically equivalent partial-sum decomposition:
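As given in the paper, for a vector x split into two blocks x = [x^(1), x^(2)], with m(·) the block max, f(·) the shifted exponentials, and ℓ(·) their sum:

```latex
m(x^{(i)}) = \max_j x^{(i)}_j, \qquad
f(x^{(i)}) = \big[\, e^{x^{(i)}_1 - m(x^{(i)})}, \dots, e^{x^{(i)}_B - m(x^{(i)})} \,\big], \qquad
\ell(x^{(i)}) = \textstyle\sum_j f(x^{(i)})_j

m(x) = \max\!\big(m(x^{(1)}),\, m(x^{(2)})\big), \qquad
\ell(x) = e^{\,m(x^{(1)}) - m(x)}\,\ell(x^{(1)}) + e^{\,m(x^{(2)}) - m(x)}\,\ell(x^{(2)})

f(x) = \big[\, e^{\,m(x^{(1)}) - m(x)} f(x^{(1)}),\;\; e^{\,m(x^{(2)}) - m(x)} f(x^{(2)}) \,\big], \qquad
\mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}
```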

and we only need to record the max and the partial exponential sum per block to break the interdependence among the blocks.
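A minimal NumPy sketch of this online softmax idea (names and block size are mine, not the paper's kernel): a single pass over the blocks maintains only the running max m and running exponential sum l, rescaling l whenever the max changes.

```python
import numpy as np

def blockwise_softmax(x, block_size):
    """softmax(x) computed one block at a time, keeping only a running
    max m and a running sum of exponentials l across blocks."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), block_size):
        xb = x[start:start + block_size]
        m_new = max(m, xb.max())
        # rescale the old partial sum to the new max, then add this block
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    # a second pass only to materialize the full softmax for checking;
    # FlashAttention instead folds each block directly into the output O
    return np.exp(x - m) / l

x = np.random.randn(1000)
ref = np.exp(x - x.max()); ref /= ref.sum()
assert np.allclose(blockwise_softmax(x, 128), ref)
```

FlashAttention folds the same rescaling directly into the accumulation of the output O instead of ever materializing the softmax (see the Memory Efficient Forward Pass section below).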

Recomputation

for the backward pass:

the backward pass for standard attention is listed below; the avoidable memory costs are the N*N matrices it materializes in and re-reads from HBM:
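With dO denoting the gradient of the loss with respect to the output O, the standard backward pass is as follows (a reconstruction consistent with the paper's Appendix B, using S = QK^T / sqrt(d)); every N*N object here (P, dP, dS) lives in HBM in the standard implementation:

```latex
\mathbf{d}V = P^{\top} \mathbf{d}O, \qquad \mathbf{d}P = \mathbf{d}O\, V^{\top}

\mathbf{d}S_{i:} = \big(\mathrm{diag}(P_{i:}) - P_{i:} P_{i:}^{\top}\big)\, \mathbf{d}P_{i:}
   \;=\; P_{i:} \circ \big(\mathbf{d}P_{i:} - (P_{i:}^{\top} \mathbf{d}P_{i:})\,\mathbf{1}\big)

\mathbf{d}Q = \tfrac{1}{\sqrt{d}}\, \mathbf{d}S\, K, \qquad
\mathbf{d}K = \tfrac{1}{\sqrt{d}}\, \mathbf{d}S^{\top} Q
```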

notably for FlashAttention, we do not save P

how to compute the gradients with saved softmax normalization statistics is discussed below.

FlashAttention Forward and Backward Implementation

Memory Efficient Forward Pass
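The body of this section was originally the paper's Algorithm 1 figure; in its place, here is a simplified single-head NumPy sketch of the tiled forward pass (no mask or dropout, arbitrary block sizes, names are mine). It keeps the output unnormalized during accumulation and divides by ℓ only at the end, which is mathematically equivalent to the paper's per-step normalization:

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=64, block_k=64):
    """Tiled attention forward pass (single head, no mask/dropout).
    Only block_q x block_k score tiles are materialized; the running
    row statistics m (max) and l (exp-sum) make the result exact."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))          # unnormalized output accumulator
    m = np.full(N, -np.inf)       # running row-wise max of the scores
    l = np.zeros(N)               # running row-wise sum of exp(score - m)

    for j in range(0, N, block_k):            # outer loop over K/V blocks
        Kj, Vj = K[j:j + block_k], V[j:j + block_k]
        for i in range(0, N, block_q):        # inner loop over Q blocks
            sl = slice(i, i + block_q)
            Sij = Q[sl] @ Kj.T * scale        # small tile ("lives in SRAM")
            m_new = np.maximum(m[sl], Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            alpha = np.exp(m[sl] - m_new)     # rescale factor for old stats
            l[sl] = alpha * l[sl] + Pij.sum(axis=1)
            O[sl] = alpha[:, None] * O[sl] + Pij @ Vj
            m[sl] = m_new
    return O / l[:, None]                     # final normalization

# sanity check against the direct computation
N, d = 256, 32
Q, K, V = (np.random.randn(N, d) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True)); P /= P.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), P @ V)
```

The per-row statistics m and ℓ are also the only extra values that need to be saved for the backward pass.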

Memory Efficient Backward Pass

==> the subscript ":" denotes a whole row (e.g. P_i:); the hollow dot "∘" denotes pointwise (Hadamard) multiplication.
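Using that notation, the key steps of the memory-efficient backward pass can be sketched as follows (a reconstruction consistent with the paper's Appendix B): P is recomputed tile by tile from Q, K and the saved statistics m, ℓ, and the row-wise sum in the softmax gradient collapses to a cheap dot product with the saved output O.

```latex
P_{ij} = \tfrac{1}{\ell_i}\, e^{\,S_{ij} - m_i}, \qquad
S_{ij} = \tfrac{1}{\sqrt{d}}\, q_i^{\top} k_j
\quad \text{(recomputed per tile, never stored)}

\mathbf{d}P_{ij} = \mathbf{d}O_{i:}^{\top} V_{j:}, \qquad
D_i = \textstyle\sum_j P_{ij}\, \mathbf{d}P_{ij} = \mathbf{d}O_{i:}^{\top} O_{i:}

\mathbf{d}S_{i:} = P_{i:} \circ \big(\mathbf{d}P_{i:} - D_i \mathbf{1}\big)

\mathbf{d}V_{j:} = \textstyle\sum_i P_{ij}\, \mathbf{d}O_{i:}, \qquad
\mathbf{d}Q_{i:} = \tfrac{1}{\sqrt{d}} \textstyle\sum_j \mathbf{d}S_{ij}\, K_{j:}, \qquad
\mathbf{d}K_{j:} = \tfrac{1}{\sqrt{d}} \textstyle\sum_i \mathbf{d}S_{ij}\, Q_{i:}
```

The dS row formula is just the softmax Jacobian, diag(P_i:) - P_i: P_i:^T, applied to dP_i:, which is where the Jacobian background below comes in.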

background on the Jacobian, quoted from an external source:

The Jacobian matrix collects all first-order partial derivatives of a multivariate function that can be used for backpropagation. The Jacobian determinant is useful in changing between variables, where it acts as a scaling factor between one coordinate space and another. 
