Attention is all you need
2017年的论文,首次引入transformer模型
摘要
传统的序列到序列问题是基于循环或卷积神经网络,而作者提出一种更简单的完全基于注意力机制的架构,在机器翻译取得了不错的成绩
介绍
RNN,LSTM,GRN已经建立起了序列模型的坚实基础
循环模型需要 h t − 1 h_{t-1} ht−1到 h t h_t ht状态连续变化,固有的顺序性限制了并行性
注意力机制可以对序列的依赖关系进行建模而不考虑序列中的距离,大多数情况下,注意力机制和循环网络结合使用
我们提出transformer,完全依赖注意力机制,允许更多的并行化
背景
为了解决顺序计算导致的计算速度变慢的问题,有网络使用卷积神经网络作为基本构建块,对所有的输入和输出进行并行计算,但这些模型中将任意两个位置关联起来所需的操作数量随着位置距离的增长而增长,学习远距离的关系变得非常困难
transformer中,学习远距离的关系代价减少到了一个恒定的操作数量——但是也降低了识别关系的有效率,所以需要引入multi-head attention机制来抵消这种影响。
模型架构
编码器和解码器
Encoder
有两个子层
-
多头注意力层:
问:为什么要引入Multi-Head机制?
答:在多头注意力机制中,原始的注意力计算被扩展成多个并行的子注意力计算,每个子注意力计算被称为一个头(head)。每个头都有自己的一组参数(Q、K、V矩阵),并生成一组注意力权重。是为了提升transformer模型识别关系的有效率
-
完全连接的前馈网络(多层感知机):
每个子层后面要接使用残差连接和进行归一化的步骤
-
残差连接:假设H(x)表示子层的输出,x表示子层的输入,那么残差连接的作用是H(x) + x,这样做有助于防止信息在多层传递中逐渐丢失,有助于减轻梯度消失问题,使得网络更容易训练。
-
归一化:
L a y e r N o r m ( H ( x ) i ) = H ( x ) i − μ σ 2 + ϵ × γ + β LayerNorm(H(x)_i)=\frac{H(x)_i-\mu}{\sqrt{\sigma ^2+ \epsilon}}\times \gamma + \beta LayerNorm(H(x)i)=σ2+ϵH(x)i−μ×γ+β
其中:- H(x)_i表示输出的第i个特征维度
- μ \mu μ表示特征维度i上的均值
- σ 2 \sigma^2 σ2表示特征维度i上的方差
- γ \gamma γ是可学习的缩放参数
- β \beta β是可学习的偏移参数
- ϵ \epsilon ϵ是一个小的常数,用于避免除以零
层归一化使得每个特征维度都具有均值为0和方差为1的分布,有助于提高训练稳定性。
Decoder
有三个子层:
-
掩码多头注意力层:
问:为什么要引入Masked?
答:为了保持自回归性质,确保在生成输出序列时只能依赖于之前已生成的内容,同时防止信息泄漏,避免生成不合理或重复的序列部分,从而提高生成质量。
-
多头注意力层:对编码器的输出Q、K和解码器上一步的输出V进行混合计算
-
全连接前馈网络MLP
多头注意力机制
允许模型在不同位置共同注意来自不同子空间的信息
Transformer模型中使用注意力机制的地方
- Decoder中的第二个子层
这个子层中,q来自解码器层,而kv矩阵来自编码器的输出
-
Encoder中的第一个子层:是一个正常的多头注意力层
-
Decoder中的第一个子层:是一个掩码多头注意力层
前馈神经网络
encoder和decoder的每个子层都会接这个网络,包括两个线性转化,中间有一个Relu激活
F
F
N
(
x
)
=
R
e
L
U
(
W
1
x
+
b
1
)
×
W
2
+
b
2
FFN(x) = ReLU(W_1 x + b_1) \times W_2 + b_2
FFN(x)=ReLU(W1x+b1)×W2+b2
位置编码
因为Transformer模型不包括递归和卷积,为了使模型利用序列的顺序,手动加入相对或绝对位置的信息
使用不同频率的正余弦函数
为什么选择Self-Attention
使用Self-Attention主要有以下三个原因:
- 每层的总计算复杂度
- 可并行化的计算量:通过所需的最小顺序操作数来衡量
- 网络中远程依赖关系之间的路径长度
训练
在WMT2014英语-德语数据集上进行训练
) = ReLU(W_1 x + b_1) \times W_2 + b_2
$$
位置编码
因为Transformer模型不包括递归和卷积,为了使模型利用序列的顺序,手动加入相对或绝对位置的信息
使用不同频率的正余弦函数
为什么选择Self-Attention
使用Self-Attention主要有以下三个原因:
- 每层的总计算复杂度
- 可并行化的计算量:通过所需的最小顺序操作数来衡量
- 网络中远程依赖关系之间的路径长度
训练
在WMT2014英语-德语数据集上进行训练