PyTorch 实现Transformer模型:Encoder-Decoder详解

Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简!该论文提出了
Transformer
模型,完全基于
Attention mechanism
,抛弃了传统的
RNN

CNN

我们根据论文的结构图,一步一步使用 PyTorch 实现这个
Transformer
模型。

Transformer架构

首先看一下transformer的结构图:

解释一下这个结构图。首先,
Transformer
模型也是使用经典的
encoer-decoder
架构,由encoder和decoder两部分组成。

上图的左半边用
Nx
框出来的,就是我们的encoder的一层。encoder一共有6层这样的结构。

上图的右半边用
Nx
框出来的,就是我们的decoder的一层。decoder一共有6层这样的结构。

输入序列经过
word embedding

positional encoding
相加后,输入到encoder。

输出序列经过
word embedding

positional encoding
相加后,输入到decoder。

最后,decoder输出的结果,经过一个线性层,然后计算softmax。

word embedding

positional encoding
我后面会解释。我们首先详细地分析一下encoder和decoder的每一层是怎么样的。

Encoder

encoder由6层相同的层组成,每一层分别由两部分组成:

  • 第一部分是一个
    multi-head self-attention mechanism
  • 第二部分是一个
    position-wise feed-forward network
    ,是一个全连接层

两个部分,都有一个
残差连接(residual connection)
,然后接着一个
Layer Normalization

如果你是一个新手,你可能会问:

  • multi-head self-attention 是什么呢?
  • 参差结构是什么呢?
  • Layer Normalization又是什么?

这些问题我们在后面会一一解答。

Decoder

和encoder类似,decoder由6个相同的层组成,每一个层包括以下3个部分:

  • 第一个部分是
    multi-head self-attention mechanism
  • 第二部分是
    multi-head context-attention mechanism
  • 第三部分是一个
    position-wise feed-forward network

还是和encoder类似,上面三个部分的每一个部分,都有一个
残差连接
,后接一个
Layer Normalization

但是,decoder出现了一个新的东西
multi-head context-attention mechanism
。这个东西其实也不复杂,理解了
multi-head self-attention
你就可以理解
multi-head context-attention
。这个我们后面会讲解。

Attention机制

在讲清楚各种attention之前,我们得先把attention机制说清楚。

通俗来说,
attention
是指,对于某个时刻的输出
y
,它在输入
x
上各个部分的注意力。这个注意力实际上可以理解为
权重

attention机制也可以分成很多种。Attention? Attention! 一问有一张比较全面的表格:

Figure 2. a summary table of several popular attention mechanisms.

上面第一种
additive attention
你可能听过。以前我们的seq2seq模型里面,使用attention机制,这种**加性注意力(additive attention)**用的很多。Google的项目 tensorflow/nmt 里面使用的attention就是这种。

为什么这种attention叫做
additive attention
呢?很简单,对于输入序列隐状态
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
和输出序列的隐状态
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
,它的处理方式很简单,直接
合并
,变成
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

但是我们的transformer模型使用的不是这种attention机制,使用的是另一种,叫做
乘性注意力(multiplicative attention)

那么这种
乘性注意力机制
是怎么样的呢?从上表中的公式也可以看出来:
两个隐状态进行点积

Self-attention是什么?

到这里就可以解释什么是
self-attention
了。

上面我们说attention机制的时候,都会说到两个隐状态,分别是
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
,前者是输入序列第i个位置产生的隐状态,后者是输出序列在第t个位置产生的隐状态。

所谓
self-attention
实际上就是,
输出序列
就是
输入序列
!因此,计算自己的attention得分,就叫做
self-attention

Context-attention是什么?

知道了
self-attention
,那你肯定猜到了
context-attention
是什么了:
它是encoder和decoder之间的attention
!所以,你也可以称之为
encoder-decoder attention
!

context-attention
一词并不是本人原创,有些文章或者代码会这样描述,我觉得挺形象的,所以在此沿用这个称呼。其他文章可能会有其他名称,但是不要紧,我们抓住了重点即可,那就是
两个不同序列之间的attention
,与
self-attention
相区别。

不管是
self-attention
还是
context-attention
,它们计算attention分数的时候,可以选择很多方式,比如上面表中提到的:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那么我们的Transformer模型,采用的是哪种呢?答案是:
scaled dot-product attention

Scaled dot-product attention是什么?

论文Attention is all you need里面对于attention机制的描述是这样的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

这句话描述得很清楚了。翻译过来就是:通过确定Q和K之间的相似程度来选择V!

用公式来描述更加清晰:

上面公式中的
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
表示的是K的维度,在论文里面,默认是
64

那么为什么需要加上这个缩放因子呢?论文里给出了解释:对于
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
很大的时候,点积得到的结果维度很

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值