从核函数角度深度理解线性注意力Linear Attention

1、映射函数

        当我们比较x与y的相似性时,通常将其从低维度映射到高维特征空间,以便在其中用简单的点积来衡量原始空间中复杂的相似性

         映射函数为\phi\left ( \cdot \right )

\phi : R^d \rightarrow R ^D \left ( D\geq d \right )

        例如:对于数据X = \left ( x_1,x_2 \right ) \in R^2

        \phi\left ( X \right ) = \phi\left ( x_1,x_2 \right ) = \left (x_1^2,x_2^2,\sqrt{2}x_1x_2,\sqrt{2}x_1,\sqrt{2}x_2,x_1x_2, 1 \right ) \in R^6

        就将2维数据变成了6维数据了

2、核函数

        常规而言,计算xy的相似度,应该先分别计算\phi\left ( x \right )\phi\left ( y \right ),再计算\phi\left ( x \right ) \cdot \phi\left ( y \right )就是xy的高维相似度了。但是这样计算太麻烦了,我们希望直接把值带入进去就可以直接知道\phi\left ( x \right ) \cdot \phi\left ( y \right )的值。

        换句话说,我们不需要显式知道\phi\left ( x \right )\phi\left ( y \right )到底等于什么,只需要一个能计算最终相似度的公式就可以了。

        于是核函数为\kappa \left (x \cdot y\right )

\kappa \left (x \cdot y\right ) = <\phi\left ( x \right ) \cdot \phi\left ( y \right ) >

        例如:如果有映射函数\phi\left ( \cdot \right )

 \phi\left ( x \right ) = \left ( x^2,\sqrt{2}x,1 \right )

        将1维数据升为3维,我们可以计算对于xy在映射空间中的相似度:

\phi\left ( x \right ) \cdot \phi\left (y \right )^T = \begin{pmatrix} x^2,\sqrt{2}x, 1 \end{pmatrix}\cdot \begin{pmatrix} y^2\\ \sqrt{2}y \\ 1 \end{pmatrix} = x^2y^2+2xy+1 =(xy+1)^2

        所以可得该情况下的核函数:

\kappa \left (x \cdot y\right ) =(xy+1)^2

        以上是显式可推导而出的核函数公式,也就是直接带入xy就可以。妙处在于,我们无需知道显式的映射\phi\left ( x \right )\phi\left ( y \right )到底等于什么,也无需在高维空间中进行计算,就能得到高维空间的内积结果。

3、注意力公式与核函数

        对于注意力公式Attn = softmax(\frac {Q\cdot K^T}{\sqrt{d_{model}}})V而言,我们先省略缩放参数d_{model},取第i个查询Q_i,则原公式为:

Attn_i = softmax(Q_i\cdot K^T)V

        进一步我们对所有的K,V做softmax的展开为:

Attn_i = \frac{\sum _{j=1}^{N} {exp(Q_i\cdot K_j^T)\cdot V_j}}{\sum _{j=1}^{N} {exp(Q_i\cdot K_j^T)}}

        当我们审视标准注意力公式时,会发现其核心是计算查询Q_i与键K_j的相似度,我们可以将exp(Q\cdot K^T)视为一种特殊的核函数。

        然而,这个核函数对应的特征映射是复杂且隐式的,虽然可以直接带入Q_iK_j求出来,但是直接Q\cdot K^T是进行矩阵乘法,复杂度o(n^2)随着n增大暴增,所以我们回归核函数定义的本质,将核函数还原为\kappa \left (q \cdot k\right ) = <\phi\left ( q \right ) \cdot \phi\left ( k \right ) >,即先分别计算\phi\left ( q \right )\phi\left ( k \right ),再计算\phi\left ( q \right ) \cdot \phi\left ( k \right ),改为多个o(n)的串行计算得到exp(Q\cdot K^T),从而避免复杂度o(n^2)的矩阵乘法

        线性注意力的核心思想是:能否设计一种新的、更简单的核函数\phi(Q\cdot K^T),并找到其对应的显式特征映射\phi(\cdot ),使得相似度计算可以写成\phi(Q\cdot K^T)\approx \phi(Q)\cdot \phi(K)^T

4、进一步公式推导

        首先在这里,exp(Q\cdot K^T)不能直接进行分解,因为是指数函数

        举例为:

\phi(xy)=exp(xy)=(exp(x))^y=(\phi(x))^y \neq \phi(x)\cdot \phi(y)^T

        所以我们只能去另外找近似的新的核函数(已经不是exp了),使得:

sim(Q\cdot K^T)\approx \phi(Q)\cdot \phi(K)^T

        由此,原式子可以化为:

LinearAttn_i = \frac{\sum _{j=1}^{N} {\phi(Q_i)\phi( K_j)^T\cdot V_j}} {\sum _{j=1}^{N} {\phi(Q_i)\phi( K_j)^T}}

        同时,因为Q_i只与i有关,求和符号只与j有关,所以可以将\phi(Q_i)提出求和符号,放在前面:

LinearAttn_i = \frac{\phi(Q_i)\sum _{j=1}^{N} {\phi( K_j)^T\cdot V_j}} {\phi(Q_i)\sum _{j=1}^{N} {\phi( K_j)^T}}

        我们可以将其进一步写作向量的形式,结合矩阵乘法的交换律优先计算\phi(K)^TV

(\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV)

5、总结        

        可以看到,线性注意力的核心就是为注意力机制选择了一个易于分解的‘核函数’\phi(Q\cdot K^T)\approx \phi(Q)\cdot \phi(K)^T

        不同的线性注意力变体(如Linear Transformer, Performer, FLASH等)的区别主要就在于特征映射函数\phi(\cdot )的设计上(例如,Performer使用exp和随机投影来近似softmax核,而最早的Linear Transformer使用elu(x)+1等简单函数)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值