softmax介绍和attention下的时间复杂度

没想到制约attention、transformer效率的竟然不只是QVT矩阵,softmax往往也占据大量计算量

回顾下softmax

Softmax函数公式为

它的时间复杂度主要取决于输入向量的长度n。对于一个长度为n的向量,Softmax函数需要对每一个元素进行指数运算(时间复杂度为O(1)),然后求和(时间复杂度为O(n)),最后再对每个元素进行除法操作(时间复杂度为O(n))。因此,整个Softmax函数的时间复杂度是O(n)

然而,如果考虑到指数运算在数值上可能非常大或非常小,实际的计算中可能需要进行一些额外的处理,比如减去最大值以防止溢出或下溢,这将稍微增加计算的复杂度,但在大O记号下仍然可以认为是O(n)。

attention下softmax时间复杂度

减去缩放因子$\frac{1}{\sqrt{d}}$后的简化attention公式为:

$\operatorname{Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})=\operatorname{Softmax}\left(\boldsymbol{Q} \boldsymbol{K}^{\top}\right) \boldsymbol{V}$

对一个 1×𝑛 的行向量进行 Softmax,时间复杂度是 𝑂(𝑛),但是对一个 𝑛×𝑛 矩阵的每一行做一个 Softmax,时间复杂度就是 𝑂(𝑛2)

如果没有 Softmax,那么 Attention 的公式就变为三个矩阵连乘 𝑄𝐾⊤𝑉,而矩阵乘法是满足结合率的,所以我们可以先算 𝐾⊤𝑉,得到一个 𝑑×𝑑 的矩阵(这一步的时间复杂度是 𝑂(𝑑2𝑛)),然后再用 𝑄 左乘它(这一步的时间复杂度是 𝑂(𝑑2𝑛)),由于 𝑑≪𝑛,所以这样算大致的时间复杂度只是 𝑂(𝑛)

举例,对于 BERT base 来说,𝑑=64 而不是 768,因为 768 实际上是通过 Multi-Head 拼接得到的,而每个单独 head 的 𝑑=64

也就是说,如果能去掉 Softmax 的 Attention 复杂度可以降到最理想的线性级别 𝑂(𝑛)!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值