Attention is all you need
论文中的实验分析部分罗列了self-attention和rnn的复杂度对比,特此记录一下自己对二者复杂度的分析。
注意:n表示序列长度,d表示向量维度。
1、self-attention的复杂度为O(n2⋅d)O(n^{2} \cdot d)O(n2⋅d),其来源自self-attention计算公式:
Attention(Q,K,V)=Softmax(QKTdk)VAttention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})VAttention(Q,K,V)=Softmax(dkQKT)V
其中,Q、K、V∈Rn×dQ、K、V\in \mathbb{R}^{n \times d}Q、K、V∈Rn×d,
QKTQK^{T}QKT是两个矩阵的乘法[n,d]×[d,n]=[n,n][n,d] \times [d,n]=[n,n][n,d]×[d,n]=[n,n],计算复杂度为n2⋅dn^{2} \cdot dn2⋅d;
其结果再乘VVV,即[n,n]×[n,d]=[n,d][n,n] \times [n,d]=[n,d][n,n]×[n,d]=[n,d],计算复杂度也为n2⋅dn^{2} \cdot dn2⋅d;
2、RNN的复杂度为O(n⋅d2)O(n \cdot d^{2})O(n⋅d2),其来源自计算公式:
ht=f(Wxhxt+bxh+Whhht−1+bhh)h_{t}=f(W_{xh}x_{t}+b_{xh}+W_{hh}h_{t-1}+b_{hh})ht=f(Wxhxt+bxh+Whhht−1+bhh) yt=g(Whyht+bht)y_{t}=g(W_{hy}h_{t}+b_{ht})yt=g(Whyht+bht)
Wxh∈Remb×dW_{xh}\in \mathbb{R}^{emb \times d}Wxh∈Remb×d,Whh∈Rd×dW_{hh}\in \mathbb{R}^{d \times d}Whh∈Rd×d,
从Whhht−1W_{hh}h_{t-1}Whhht−1来看,虽然WhhW_{hh}Whh在前边,但是做矩阵乘法的时候是 ht−1×WhhTh_{t-1} \times W_{hh}^{T}ht−1×WhhT,即[1,d]×[d,d]=[1,d][1,d] \times [d,d]=[1,d][1,d]×[d,d]=[1,d],计算复杂度为d⋅dd \cdot dd⋅d;
以上是一个输入的计算复杂度,n个输入的计算复杂度为n⋅d2n \cdot d^{2}n⋅d2。