上一篇在这 长上下文训练的关键因素(1) (qq.com)
我看有读者留言说想看Flash-attention,那么今天就讲它
这东西算法其实全讲老复杂了,我们挑着好理解方式讲
我们上节课讲到了计算复杂度的关系,就是以下这部分,先来复习一下
QKV都是由self-attention,也就是由输入数据self出来的,输入的数据[B,n,D]的最后一维和Wq,Wk,Wv三个矩阵是相等的。
Wq矩阵=[D,D],Wk和Wv也都一样, 然后输入数据分别和Wq,Wk,Wv点乘出来QKV3个值。
拿Q举例,输入数据[B,n,D]要和q矩阵[D,D]点积,生成出来的东西是[B,n,D],K和V也一样。
然后按照公式来计算

Q*K的转置,就是[B,n,D]*[B,n,D],就是[B,n,n]
如果考虑到多头注意力因素就是[B,header_number,n,head_dim]*
[B,header_number,n,head_dim]=[B,header_number,n,n],都差不多
不管是不是多头,点乘的计算量都是B*n^2*D,D固定的,B是batch_size,也不用考虑,所以计算量相关性最大的变数就剩下n^2了,n越大,肯定平方值越大,复杂度越高,这个就是算力复杂度和n的平方相关的由来。
至于后面QK的softmax结果和V相乘也是一样的,但是因为他俩是2次的顺序计算,所以还是和n的平方相关,也是transformer最被诟病的地方。
然后上一篇结尾我也说了,Flash-attention和别的实现厂商长下文的,实现思路不一样,那它为啥不一样呢?
首先说它干了什么事
1. 离GPU的计算资源最近
2.

Flash-Attention是一种优化深度学习中注意力机制的方法,旨在解决长上下文训练的计算复杂度和内存消耗问题。通过靠近GPU计算、减少数据在HBM与SRAM间传输、省内存等手段,它提高了训练效率,降低了内存墙的影响。Flash-Attention采用矩阵分块策略,避免了全矩阵运算,实现了内存使用和IO操作的减少,从而在大多数情况下将内存复杂度从n²降低到n。
最低0.47元/天 解锁文章
3932

被折叠的 条评论
为什么被折叠?



