Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简!该论文提出了
Transformer
模型,完全基于
Attention mechanism
,抛弃了传统的
RNN
和
CNN
。
我们根据论文的结构图,一步一步使用 PyTorch 实现这个
Transformer
模型。
Transformer架构
首先看一下transformer的结构图:
解释一下这个结构图。首先,
Transformer
模型也是使用经典的
encoer-decoder
架构,由encoder和decoder两部分组成。
上图的左半边用
Nx
框出来的,就是我们的encoder的一层。encoder一共有6层这样的结构。
上图的右半边用
Nx
框出来的,就是我们的decoder的一层。decoder一共有6层这样的结构。
输入序列经过
word embedding
和
positional encoding
相加后,输入到encoder。
输出序列经过
word embedding
和
positional encoding
相加后,输入到decoder。
最后,decoder输出的结果,经过一个线性层,然后计算softmax。
word embedding
和
positional encoding
我后面会解释。我们首先详细地分析一下encoder和decoder的每一层是怎么样的。
Encoder
encoder由6层相同的层组成,每一层分别由两部分组成:
- 第一部分是一个
multi-head self-attention mechanism - 第二部分是一个
position-wise feed-forward network
,是一个全连接层
两个部分,都有一个
残差连接(residual connection)
,然后接着一个
Layer Normalization
。
如果你是一个新手,你可能会问:
- multi-head self-attention 是什么呢?
- 参差结构是什么呢?
- Layer Normalization又是什么?
这些问题我们在后面会一一解答。
Decoder
和encoder类似,decoder由6个相同的层组成,每一个层包括以下3个部分:
- 第一个部分是
multi-head self-attention mechanism - 第二部分是
multi-head context-attention mechanism - 第三部分是一个
position-wise feed-forward network
还是和encoder类似,上面三个部分的每一个部分,都有一个
残差连接
,后接一个
Layer Normalization
。
但是,decoder出现了一个新的东西
multi-head context-attention mechanism
。这个东西其实也不复杂,理解了
multi-head self-attention
你就可以理解
multi-head context-attention
。这个我们后面会讲解。
Attention机制
在讲清楚各种attention之前,我们得先把attention机制说清楚。
通俗来说,
attention
是指,对于某个时刻的输出
y
,它在输入
x
上各个部分的注意力。这个注意力实际上可以理解为
权重
。
attention机制也可以分成很多种。Attention? Attention! 一问有一张比较全面的表格:
Figure 2. a summary table of several popular attention mechanisms.
上面第一种
additive attention
你可能听过。以前我们的seq2seq模型里面,使用attention机制,这种**加性注意力(additive attention)**用的很多。Google的项目 tensorflow/nmt 里面使用的attention就是这种。
为什么这种attention叫做
additive attention
呢?很简单,对于输入序列隐状态
和输出序列的隐状态
,它的处理方式很简单,直接
合并
,变成
但是我们的transformer模型使用的不是这种attention机制,使用的是另一种,叫做
乘性注意力(multiplicative attention)
。
那么这种
乘性注意力机制
是怎么样的呢?从上表中的公式也可以看出来:
两个隐状态进行点积
!
Self-attention是什么?
到这里就可以解释什么是
self-attention
了。
上面我们说attention机制的时候,都会说到两个隐状态,分别是
和
,前者是输入序列第i个位置产生的隐状态,后者是输出序列在第t个位置产生的隐状态。
所谓
self-attention
实际上就是,
输出序列
就是
输入序列
!因此,计算自己的attention得分,就叫做
self-attention
!
Context-attention是什么?
知道了
self-attention
,那你肯定猜到了
context-attention
是什么了:
它是encoder和decoder之间的attention
!所以,你也可以称之为
encoder-decoder attention
!
context-attention
一词并不是本人原创,有些文章或者代码会这样描述,我觉得挺形象的,所以在此沿用这个称呼。其他文章可能会有其他名称,但是不要紧,我们抓住了重点即可,那就是
两个不同序列之间的attention
,与
self-attention
相区别。
不管是
self-attention
还是
context-attention
,它们计算attention分数的时候,可以选择很多方式,比如上面表中提到的:
- additive attention
- local-base
- general
- dot-product
- scaled dot-product
那么我们的Transformer模型,采用的是哪种呢?答案是:
scaled dot-product attention
。
Scaled dot-product attention是什么?
论文Attention is all you need里面对于attention机制的描述是这样的:
An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.
这句话描述得很清楚了。翻译过来就是:通过确定Q和K之间的相似程度来选择V!
用公式来描述更加清晰:
上面公式中的
表示的是K的维度,在论文里面,默认是
64
。
那么为什么需要加上这个缩放因子呢?论文里给出了解释:对于
很大的时候,点积得到的结果维度很