长上下文训练的关键因素(2)-flash-attention

Flash-Attention是一种优化深度学习中注意力机制的方法,旨在解决长上下文训练的计算复杂度和内存消耗问题。通过靠近GPU计算、减少数据在HBM与SRAM间传输、省内存等手段,它提高了训练效率,降低了内存墙的影响。Flash-Attention采用矩阵分块策略,避免了全矩阵运算,实现了内存使用和IO操作的减少,从而在大多数情况下将内存复杂度从n²降低到n。

上一篇在这 长上下文训练的关键因素(1) (qq.com)

     我看有读者留言说想看Flash-attention,那么今天就讲它

     这东西算法其实全讲老复杂了,我们挑着好理解方式讲

     我们上节课讲到了计算复杂度的关系,就是以下这部分,先来复习一下

     QKV都是由self-attention,也就是由输入数据self出来的,输入的数据[B,n,D]的最后一维和Wq,Wk,Wv三个矩阵是相等的。

     Wq矩阵=[D,D],Wk和Wv也都一样, 然后输入数据分别和Wq,Wk,Wv点乘出来QKV3个值。

      拿Q举例,输入数据[B,n,D]要和q矩阵[D,D]点积,生成出来的东西是[B,n,D],K和V也一样。

       然后按照公式来计算

图片

       Q*K的转置,就是[B,n,D]*[B,n,D],就是[B,n,n]

       如果考虑到多头注意力因素就是[B,header_number,n,head_dim]*

   [B,header_number,n,head_dim]=[B,header_number,n,n],都差不多

    不管是不是多头,点乘的计算量都是B*n^2*D,D固定的,B是batch_size,也不用考虑,所以计算量相关性最大的变数就剩下n^2了,n越大,肯定平方值越大,复杂度越高,这个就是算力复杂度和n的平方相关的由来。

     至于后面QK的softmax结果和V相乘也是一样的,但是因为他俩是2次的顺序计算,所以还是和n的平方相关,也是transformer最被诟病的地方。

      然后上一篇结尾我也说了,Flash-attention和别的实现厂商长下文的,实现思路不一样,那它为啥不一样呢?

      首先说它干了什么事

      1. 离GPU的计算资源最近

      2.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值