没想到制约attention、transformer效率的竟然不只是QVT矩阵,softmax往往也占据大量计算量
回顾下softmax
Softmax函数公式为
它的时间复杂度主要取决于输入向量的长度n。对于一个长度为n的向量,Softmax函数需要对每一个元素进行指数运算(时间复杂度为O(1)),然后求和(时间复杂度为O(n)),最后再对每个元素进行除法操作(时间复杂度为O(n))。因此,整个Softmax函数的时间复杂度是O(n)。
然而,如果考虑到指数运算在数值上可能非常大或非常小,实际的计算中可能需要进行一些额外的处理,比如减去最大值以防止溢出或下溢,这将稍微增加计算的复杂度,但在大O记号下仍然可以认为是O(n)。
attention下softmax时间复杂度
减去缩放因子后的简化attention公式为:
对一个 1×𝑛 的行向量进行 Softmax,时间复杂度是 𝑂(𝑛),但是对一个 𝑛×𝑛 矩阵的每一行做一个 Softmax,时间复杂度就是 𝑂(𝑛2)
如果没有 Softmax,那么 Attention 的公式就变为三个矩阵连乘 𝑄𝐾⊤𝑉,而矩阵乘法是满足结合率的,所以我们可以先算 𝐾⊤𝑉,得到一个 𝑑×𝑑 的矩阵(这一步的时间复杂度是 𝑂(𝑑2𝑛)),然后再用 𝑄 左乘它(这一步的时间复杂度是 𝑂(𝑑2𝑛)),由于 𝑑≪𝑛,所以这样算大致的时间复杂度只是 𝑂(𝑛)
举例,对于 BERT base 来说,𝑑=64 而不是 768,因为 768 实际上是通过 Multi-Head 拼接得到的,而每个单独 head 的 𝑑=64
也就是说,如果能去掉 Softmax 的 Attention 复杂度可以降到最理想的线性级别 𝑂(𝑛)!