公式
这篇博客主要讲基础的 seq2seq 中 attention 机制:
- 输入:X=(x1,x2,⋯ ,xTx)X = (x_1, x_2, \cdots, x_{T_x})X=(x1,x2,⋯,xTx)
- 输出:Y=(y1,y2,⋯ ,yTy)Y = (y_1, y_2, \cdots, y_{T_y})Y=(y1,y2,⋯,yTy)
公式推导:
- ht=RNNenc(xt,ht−1)h_t = RNN_{enc}(x_t, h_{t-1})ht=RNNenc(xt,ht−1), Encoder 只要它的 hidden state。
- st=RNNdec(yt−1,st−1)s_t = RNN_{dec}(y_{t-1}, s_{t-1})st=RNNdec(yt−1,st−1),这里相当于用 teacher forcing,这里 s 也是指 hidden state。
此时,做 attention:
- eij=score(si−1,hj)e_{ij} = score(s_{i-1}, h_j)eij=score(si−1,hj),si−1s_{i-1}si−1 先跟每个 hhh 分别计算的得到一个分数,这样所以 sss 计算后,得到一个矩阵,行相当于代表每个词,列相当于分配给每个 hhh 的权重。也就是每个decoder的 hidden state 与每个 encoder 的 hidden state 计算一个相似度。
- αij=exp(eij)∑k=1Txexp(eik)\alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x} exp(e_{ik})}αij=∑k=1Txexp(eik)exp(eij),softmax操作。
- ci=∑j=1Txαijhjc_i = \sum_{j=1}^{T_x}\alpha_{ij}h_jci=∑j=1Txαijhj,attention output。
最后
- st^=tanh(Wc[ct;st])\hat{s_t} = tanh(W_c[c_t; s_t])st^=tanh(Wc[ct;st]),concate 操作,WWW 为需要学习的参数。
- p(yt∣y<t,x)=softmax(Wsst^)p(y_t|y_{<t}, x) = softmax(W_s\hat{s_t})p(yt∣y<t,x)=softmax(Wsst^),输出概率。
备注:计算score那一步有几种操作,可以直接点乘、加一个可学习矩阵相乘、cos相似度、多层感知机等:
score(si,hi)={siThisiTWhivTtanh(W[si;hi])score(s_i, h_i)=\left\{
\begin{aligned}
& s_i^{\mathrm T}h_i \\
& s_i^{\mathrm T}Wh_i \\
& v^{\mathrm T}tanh(W[s_i;h_i])
\end{aligned}
\right.score(si,hi)=⎩⎪⎨⎪⎧siThisiTWhivTtanh(W[si;hi])