transformer学习| Encoder
前言
今天开始学习DataWhale组队学习fun-transformer,今天学习task3 Encoder部分!
【教程地址】
【开源地址】
1. 编码器(Encoder)
1.1 Encoder 组成成分
在Transformer模型中,整个编码器是由多个相同的编码器层(Encoder Layer)堆叠而成的。每一层编码器层都像是一个“小工厂”,负责处理输入的数据,并将处理后的结果传递给下一层,每一层编码器层的结构都是一样的。
每一层编码器层内部主要包含两个“工作车间”(子层):
- 多头自注意力层(Multi-Head Self-Attention Layer)
这个“车间”的工作是让输入的句子中的每个单词都能“看到”其他单词,并且根据它们之间的关系进行加权处理。
具体而言,它通过缩放点积注意力(Scaled Dot-product Attention)来计算单词之间的关联强度,并且通过多头注意力(Multi-Head Attention)机制从多个角度捕捉这些关系。
举个例子,假设我们有句话“我爱你”,多头自注意力层可以让“我”看到“爱”和“你”,并且计算出“我”和“爱”、“我”和“你”之间的关系有多重要。
在这个过程中,还会用到残差连接(Residual Connection)和层归一化(Layer Normalization)。
残差连接的意思是,处理完的结果会和原始输入相加,这样可以避免网络太深导致的信息丢失。
层归一化则是对数据进行标准化处理,让模型训练得更稳定。
2.逐位置前馈网络层(Position-wise Feed-Forward Network Layer)
这个“车间”的工作是对每个单词分别进行进一步的处理。它是一个简单的神经网络,会对每个单词的特征进行变换,让模型能够学习到更复杂的模式。
比如,经过多头自注意力层处理后的“我”、“爱”、“你”会在这个层中被分别加工,让每个单词的特征更加丰富。
1.2 Encoder 工作流程
1.2.1 输入阶段
- 初始输入:整个 Encoder 部分由 6 个相同的子模块按顺序连接构成。第一个 Encoder 子模块接收来自嵌入(Input Embedding)和位置编码(Position Embedding)组合后的输入(inputs)。输入嵌入是将输入的原始数据(如文本中的单词等)转化为向量表示,方便模型处理。而位置编码的作用是让模型能够捕捉到输入序列中元素的位置信息,因为标准的向量表示本身没有位置概念,位置编码可以帮助模型区分同一单词在不同位置时的不同含义和作用,这对于理解句子等序列数据的语义至关重要。
- 后续 Encoder 输入:除了第一个 Encoder 之外的其他 Encoder 子模块,它们从前一个 Encoder 接收相应的输入(inputs),这样就形成了一个顺序传递信息的链路。这种顺序传递的方式使得每个 Encoder 子模块都能在前一个子模块处理的基础上,进一步对信息进行加工和提取,让模型对输入序列的理解和表示逐渐深入和准确。
1.2.2 核心处理阶段
- 多头自注意力层处理 :每个 Encoder 子模块在接收到输入后,首先会将其传递到多头自注意力层(Multi - Head Self - Attention layer)。在这一层中,通过多头自注意力机制(查询Q、键K、值V都来自同一个输入序列自身)去计算输入序列不同位置之间的关联关系,生成相应的自注意力输出。多头自注意力机制允许模型在不同的子空间中学习到不同的关系,每个头都有自己的 Q、K 和 V,最后将所有头的输出通过一个线性层拼接起来。这种机制可以让模型同时关注输入序列中的多个不同方面,比如在文本处理中,可以同时捕捉单词之间的语义关联、语法结构依存关系等,从而更全面地理解输入序列。
- 前馈层处理 :自注意力层的输出紧接着被传递到前馈层(Feedforward layer)。前馈层一般是由全连接网络等构成,对自注意力层输出的特征做进一步的非线性变换,提取更复杂、高层次的特征。通过这种非线性变换,模型能够捕捉到输入数据中更复杂的模式和规律,从而更好地理解输入序列的含义。然后将其输出向上发送到下一个编码器(如果不是最后一个 Encoder 的话),以便后续 Encoder 子模块继续进行处理,进一步优化特征表示。
1.2.3 残差与归一化阶段
- 残差连接(Residual Connection) :自注意力层和前馈子层(Feedforward sublayer)均配备了残差快捷链路。这种连接方式构建了一种并行通路,使得输入信号能够以残差的形式参与到每一层的输出计算中。对于自注意力层,其输入会与自注意力层的输出进行相加操作(假设自注意力层输入为 x,输出为 y,经过残差连接后变为 x + y),前馈层的输入也会和前馈层的输出进行相加。残差连接有助于缓解深度网络训练过程中的梯度消失或梯度爆炸问题,使得网络能够更容易地训练深层模型,并且能够让信息更顺畅地在网络中传递,保证了模型在较深的网络结构下仍能有效地学习和更新参数。
- 层归一化(Layer Norm) :在残差连接之后,紧跟着会进行层归一化操作。层归一化是对每一层的神经元的输入进行归一化处理,它可以加速网络的收敛速度、提高模型的泛化能力等,使得模型训练更加稳定、高效。经过层归一化后的结果就是当前 Encoder 子模块最终的输出,然后传递给下一个 Encoder 子模块或者后续的其他模块(比如在 Encoder - Decoder 架构中传递给 Decoder 部分等情况)。层归一化通过调整每一层神经元的输入分布,使其保持相对稳定,避免了由于输入分布的变化导致的训练不稳定和收敛速度慢等问题,同时也有助于提高模型在不同数据集和任务上的泛化性能。
2. 多头自注意力(Multi-Head Self-Attention)
- 多头注意力机制
- 子空间学习:Transformer模型采用多头注意力机制,可以将输入数据分割成多个不同的子空间,每个子空间对应一个“头”。每个头都有自己的Q(查询)、K(键)和V(值),这样模型就能在不同的子空间中学习到不同的关系和特征。比如在一个文本序列中,一个头可能专注于学习单词之间的语法关系,而另一个头可能侧重于学习语义关联。
- 线性层拼接输出:各个头分别计算出自己的输出后,这些输出会被拼接起来,然后通过一个线性层进行整合,得到最终的多头注意力输出。这个线性层的作用是将不同头学习到的信息融合在一起,使模型能够综合考虑各个子空间中的关系和特征,从而更全面地理解输入数据。
- Q、K、V的含义及起源
- Query(查询):在注意力机制中,Query代表了正在询问的信息或关心的上下文。以文本处理为例,在自注意力机制里,序列中的每个元素(如单词)都会有一个对应的查询向量,它就像是在向其他元素发出询问,试图从其他部分找到与自己相关的信息,以便更好地理解自身在整个序列中的含义和作用。
- Key(键):Key是可以被查询的条目或“索引”。在自注意力机制中,每个序列元素同样对应一个键向量,它就像是一个标识,能够让其他元素通过查询向量来判断与自己的相关性。当一个查询向量与某个键向量匹配程度较高时,就说明对应的元素与查询元素有较强的关联。
- Value(值):对于每一个键,都有一个与之关联的值,它代表实际的信息内容。当查询匹配到一个特定的键时,其对应的值就会被选中并返回,作为对查询的一种回答,为模型提供有用的信息,帮助模型更好地理解和处理序列数据。
- 与数据库查询概念的相似性:
这种Q、K、V的叫法和它们在注意力机制中的作用,与数据库中的查询概念有相似之处。在数据库查询中,用户输入查询(Query),数据库会根据查询内容在存储的数据中查找与之匹配的键(Key),然后返回对应的值(Value),从而为用户提供所需的信息。 - 子空间的直观理解
可以将子空间想象成一个“局部视角”或“局部特征空间”。每个头通过自己的子空间,专注于学习数据的某个特定方面,而这些方面最终被整合起来,形成对数据的全面理解。
2.1 缩放点积注意力(Scaled Dot-Product Attention)
查询、键、值的生成
- 输入:自注意力机制的输入是序列词向量,记为 x x x 。
- 线性变换:输入
x
x
x 经过三个不同的线性变换,分别得到查询(Query,记为
Q
Q
Q)、键(Key,记为
K
K
K)和值(Value,记为
V
V
V)。
- 查询 Q Q Q是通过线性变换 l i n e a r q ( x ) linear_q(x) linearq(x)得到的,表示模型想要查询或者关注的信息。
- 键 K K K是通过线性变换 l i n e a r k ( x ) linear_k(x) lineark(x)得到的,相当于数据库中的“索引”,用于与其他查询进行匹配。
- 值 V V V是通过线性变换 l i n e a r v ( x ) linear_v(x) linearv(x)得到的,包含了实际的信息内容,当查询与某个键匹配时,对应的值就会被选中并返回。
- 这三个线性变换是相互独立的,即
l
i
n
e
a
r
q
(
x
)
linear_q(x)
linearq(x)、
l
i
n
e
a
r
k
(
x
)
linear_k(x)
lineark(x)、
l
i
n
e
a
r
v
(
x
)
linear_v(x)
linearv(x) 使用不同的参数矩阵对输入
x
x
x进行映射,从而得到不同用途的向量
Q
Q
Q、
K
K
K、
V
V
V 。
注意力权重的计算与归一化
- 点积计算:对于序列中的每一对查询和键,计算它们的点积 Q K T QK^T QKT,得到一个分数矩阵,这个分数表示查询和键之间的相似性或者匹配程度,分数越高,说明查询和键越相似,对应的值就越重要。
- 缩放操作:为了避免点积结果过大导致后续 softmax 函数计算时出现数值不稳定问题,引入了缩放因子 d k \sqrt{d_k} dk,其中 d k d_k dk 是键向量的维度。将点积结果除以 d k \sqrt{d_k} dk 进行缩放,得到缩放后的分数矩阵 Q K T d k \frac{QK^T}{\sqrt{d_k}} dkQKT。
- softmax 归一化:对缩放后的分数矩阵应用 softmax 函数,将每个分数转换为一个概率值,即注意力权重(attention weights),这些权重值介于 0 到 1 之间,并且所有权重之和为 1。这样可以确保每个查询对应的权重分布是合理的,能够反映不同键与查询的相对重要性。
注意力输出的计算
- 加权求和:使用得到的注意力权重对值向量 (V) 进行加权求和,即 (Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V) 。具体来说,对于每个查询,根据其对应的注意力权重,对所有值向量进行加权,得到一个综合考虑了上下文信息的输出向量。
- 最终结果:这个输出向量就是自注意力机制的最终结果,它融合了序列中不同位置的信息,使得每个位置的表示都包含了其他位置的相关信息,从而能够更好地捕捉序列内部的依赖关系和上下文信息。
2.1.1 如何得到缩放因子
缩放因子dk的计算方式
在多头注意力机制里,每个头的维度
d
k
d_k
dk是由模型的总维度
d
m
o
d
e
l
d_{model}
dmodel和头数h共同决定的,计算公式是
d
k
=
d
m
o
d
e
l
h
d_k=\frac{d_{model}}{h}
dk=hdmodel。例如,若
d
m
o
d
e
l
d_{model}
dmodel为512,h为8,那么
d
k
d_k
dk就等于64。
这样计算缩放因子的原因
- 参数平衡:采用这种计算方法,能够保证每个头所拥有的参数数量是相同的,并且使得整体多头注意力的参数总量与单头注意力机制相近。这有利于模型在扩展时保持性能的稳定性,方便对模型进行管理和优化。
- 计算效率:由于 d k d_k dk是 d m o d e l d_{model} dmodel的因子,在进行矩阵运算时,这种关系能够更好地利用硬件的加速功能,从而提高计算的效率,加快模型的训练和推理速度。
- 多样性:当有多个头时,每个头都在各自的子空间中进行操作。通过合理设置 d k d_k dk,可以让模型在不同的子空间中捕获输入数据之间更丰富多样的关系,从而提升模型对数据的理解和表达能力。
- 可解释性和调试:选择合适的 d k d_k dk,可以使每个注意力头的维度相对较小,这样在分析和解释每个头的作用以及进行调试时会更加容易。在一些特殊的场景下,还可以根据具体需求手动设置dk的值。
2.1.2 缩放因子的作用
关于缩放因子
- 缩放因子的引入原因:正常的点积注意力中,Q和K矩阵相乘后的内积会随着维度的增加而增大。而当内积过大时,经过softmax函数处理后,很容易落入饱和区间。在softmax函数的饱和区间,其梯度值会趋近于0,这会导致梯度消失问题,进而影响模型的训练效果,使模型难以收敛。
- 缩放因子的作用:缩放因子 d k \sqrt{d_k} dk 的引入,可以对Q和K的内积进行缩放,使其保持在一个合理的范围内,避免内积过大。这样,经过softmax函数处理后,输出值不会轻易进入饱和区间,从而保证了梯度的稳定,避免梯度消失问题的发生,有助于模型的稳定训练。
矩阵相乘与Attention层
- 矩阵相乘的过程:如果忽略激活函数softmax,那么在计算注意力时,实际上是Q、K、V三个矩阵相乘。具体过程是先计算Q和K的转置的点积,得到一个注意力分数矩阵,再将该分数矩阵与V矩阵相乘,最终得到一个维度为 n ⋅ d v n \cdot d_v n⋅dv 的矩阵。其中,n表示序列的长度, d v d_v dv表示V矩阵的维度。
- Attention层的功能:从这个过程可以看出,Attention层的作用是将输入序列Q编码成一个新的序列,其维度为 n ⋅ d v n \cdot d_v n⋅dv。在这个新的序列中,每个元素都是根据Q、K、V之间的关系计算得到的,能够更好地捕捉序列中的关键信息和依赖关系,从而为后续的模型层提供更有用的特征表示。
mask部分的作用
- 训练时关闭mask的原因:在训练阶段,模型的目标是学习整个序列中的信息和关系,以便更好地理解和预测序列中的每个元素。因此,此时不需要对序列进行遮蔽,而是让模型能够充分地利用整个序列的信息来学习,从而提高模型的性能。
- 测试或推理时打开mask的作用:在测试或推理阶段,模型通常是逐个生成序列中的元素,例如在自然语言处理中的文本生成任务中,模型需要根据已生成的部分来预测下一个词。此时,为了保证预测的合理性,需要使用mask来遮蔽当前预测词后面的序列,防止模型在预测时“看到”未来的信息,从而确保预测结果的准确性和合理性。
2.1.3 计算 attention 时为何选择点乘
计算效率方面
- 点乘注意力:计算效率更高,它可以通过矩阵乘法进行并行优化。在实际硬件中,矩阵乘法操作可以充分利用硬件的并行计算能力,从而实现高效的计算。这种特性使得点乘注意力特别适合大规模的模型训练和推理。在大规模模型中,有大量的数据和复杂的计算需求,高效的计算方式能够显著加快模型的训练和推理速度,提高整个系统的运行效率。
- 加法注意力:在计算效率上相对较低,没有像点乘注意力那样明显的并行优化优势,因此在大规模模型训练和推理时,其速度可能会受到一定限制。
计算复杂度方面
- 理论复杂度:从理论上讲,点乘和加法注意力的计算复杂度都是 O ( d ) O(d) O(d),其中 d d d 表示向量的维度。这意味着在理想情况下,随着向量维度的增加,两种注意力机制的计算量都会线性增长。
- 实际硬件表现:尽管理论复杂度相同,但在实际硬件环境中,点乘注意力通过并行化能够显著提升计算速度。这是因为点乘操作可以很容易地分解为多个独立的子操作,并且这些子操作可以在现代硬件(如 GPU)上同时进行,从而充分利用硬件的并行计算能力,实现高效的计算。而加法注意力在硬件并行化方面没有点乘注意力那么高效,因此在实际应用中,其计算速度可能会相对较慢。
效果方面
- 点乘注意力:
- 能够有效衡量向量的相似性。这是因为它计算的是两个向量对应元素乘积的和,这种计算方式能够直观地反映两个向量在各个维度上的匹配程度,从而较好地衡量它们的相似性。
- 在高维度向量时,通过缩放可以避免数值不稳定问题。当向量的维度很高时,点乘的结果可能会变得非常大,导致数值不稳定,影响模型的训练和推理效果。通过引入缩放因子,可以将点乘的结果控制在一个合理的范围内,从而避免数值不稳定问题,确保模型的稳定性和准确性。
- 加法注意力:
- 由于引入了非线性操作,效果上并无显著提升。加法注意力在计算过程中会引入一些非线性变换,这些非线性操作虽然在理论上可以增加模型的表达能力,但在实际应用中,并没有明显提升衡量向量相似性的效果。
- 计算更为复杂。加法注意力的计算过程相对复杂,涉及到更多的操作步骤和参数,这不仅增加了计算量,还可能导致模型的训练和推理速度变慢,同时也增加了模型的复杂度和调试难度。
2.2 多头注意力机制(Multi-Head Attention)
多头注意力机制的工作方式
- 多头的含义:多头注意力机制是将注意力机制进行了扩展,将输入的查询(Q)、键(K)和值(V)分别通过不同的参数矩阵映射到多个不同的子空间中,形成多个“头”,每个头都独立地计算注意力,然后将这些头的输出进行拼接,最后通过一个线性变换得到最终的输出。这种多头的设计可以让模型在不同的子空间中学习到不同的关系,从而更全面地捕捉输入数据中的信息。
- 具体计算过程:对于每个头,都会对输入的Q、K、V进行线性变换,即分别乘以对应的权重矩阵 ( W i Q (W_i^Q (WiQ、 W i K W_i^K WiK、 W i V W_i^V WiV,然后对变换后的结果应用注意力机制,得到每个头的输出 h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)。其中, W i Q W_i^Q WiQ、 W i K W_i^K WiK、 W i V W_i^V WiV 是权重矩阵,其维度是根据模型的维度和头的数量来确定的。
- 拼接与线性变换:将所有头的输出 h e a d 1 , h e a d 2 , . . . , h e a d h head_1,head_2,...,head_h head1,head2,...,headh 进行拼接,得到一个维度为 n × ( h d ~ v ) n×(h\tilde{d}_v) n×(hd~v)的序列,其中 n是序列的长度,h是头的数量, d ~ v \tilde{d}_v d~v 是每个头输出的维度。然后对拼接后的结果进行线性变换,得到最终的多头注意力输出。
头的数量和权重矩阵的设置
- 头的数量:头的数量 (h) 是一个超参数,需要根据具体任务和数据来调整。一般来说,头的数量并不是越大越好,当头的数量达到一定程度后,模型的性能提升会趋于平稳,甚至可能会出现性能下降的情况。在Google的Transformer模型中,头的数量通常设置为8。
- 权重矩阵的不同:每个头的权重矩阵 W i Q W_i^Q WiQ、 W i K W_i^K WiK、 W i V W_i^V WiV 是不同的,这使得每个头可以学习到不同的特征和关系,从而增加了模型的多样性和表示能力。
- 实验观察:有实验表明,随着层数的增加,头之间的差异会逐渐减小,这意味着在较深的层次上,不同头所学习到的信息可能会趋于相似。因此,这种头之间的差异性是否是模型真正追求的,还需要进一步的研究和验证。
2.2.1 多头注意力机制的优点:
- 并行计算:由于每个头的计算是独立的,所以可以并行进行,这大大提高了计算效率,使得模型能够更快地处理数据。
- 增强表示能力:每个头可以关注输入序列的不同方面,从而使得模型能够捕捉到更丰富、更复杂的模式和关系,增强了模型对输入数据的表示能力。
- 提高泛化性:多头注意力机制使得模型能够同时关注序列内的局部和全局依赖关系,这有助于提高模型在不同任务和领域的泛化能力。
2.2.2 多头注意力的计算
线性变换
- 输入序列的投影:输入序列首先会经历可学习的线性变换。具体来说,就是将输入序列映射到多个不同的子空间中,每个子空间对应一个“头”(head)。这个过程可以类比为将一个高维的数据空间分解为多个低维的子空间,每个子空间都包含原始数据的一部分特征。
- “头”的作用:每个“头”都会关注输入序列的不同方面。这就像是从不同的角度去观察和理解输入序列,使得模型能够捕捉到输入序列中各种不同的模式和关系。例如,在处理文本数据时,一个头可能专注于捕捉句子中的语法结构,而另一个头可能更关注单词之间的语义关联。
缩放点积注意力
- 计算注意力分数:在每个“头”中,都会独立地计算输入序列的查询(Query)、键(Key)和值(Value)表示之间的注意力分数。具体计算方式是通过计算查询向量和键向量的点积来衡量它们之间的相似度。这个相似度分数就表示了当前令牌(token)与其他令牌之间的关联程度。
- 缩放操作:计算得到的点积结果会除以模型深度的平方根进行缩放。缩放的目的是为了防止内积结果过大,从而避免在后续的softmax函数中出现梯度消失的问题。因为softmax函数对于较大的输入值会变得非常平坦,导致梯度几乎为零,这会影响模型的训练效果。
- 注意力权重:经过缩放后的点积结果会通过softmax函数进行归一化,得到注意力权重。这些权重值介于0到1之间,表示每个令牌相对于其他令牌的重要性。权重较高的令牌意味着它在当前上下文中更为重要,对输出的贡献也更大。
连接和线性投影
- 输出的连接:来自所有“头”的注意力输出会被连接起来。这一步是将不同“头”所捕捉到的信息进行整合,将它们的输出合并为一个完整的表示。这样可以充分利用每个“头”所学习到的不同模式和关系,使模型对输入序列的理解更加全面。
- 线性投影回原始维度:连接后的输出会经过一个线性变换,将其投影回原始的维度。这个线性变换的作用是将不同“头”的输出进行融合和调整,使得最终的输出能够更好地适应模型的后续处理。通过这个过程,模型能够将来自多个“头”的见解结合起来,从而增强其对序列内复杂关系的理解能力。
2.3 自注意力机制(Self Attention)
自注意力机制的定义
Cheng 等人在论文《Long Short-Term Memory-Networks for Machine Reading》中将自注意力(Self-Attention)定义为一种将单个序列或句子的不同位置关联起来以获得更有效表示的机制。这意味着它能够让模型在处理序列数据时,不仅仅关注序列中的单个元素,而是同时考虑序列中所有元素之间的相互关系,从而得到更全面、更丰富的序列表示。
自注意力机制的分类
- Encoder Self-Attention:在编码器(Encoder)阶段使用,其目的是捕获当前单词与其他输入单词之间的关联关系。通过这种方式,模型能够理解整个输入序列中各个单词之间的相互作用,为后续的编码过程提供更准确的上下文信息。
- Masked Decoder Self-Attention:在解码器(Decoder)阶段使用,用于捕获当前单词与已经看到的解码词之间的关联。这里的“Masked”表示对解码器的输入进行了遮蔽处理,即在计算注意力时,当前单词只能关注到它之前已经生成的单词,而不能看到未来还未生成的单词。从矩阵的角度来看,这种注意力机制对应的是一个带有遮蔽的三角矩阵,确保了解码过程的自回归性质。
- Encoder-Decoder Attention:这种注意力机制是将解码器(Decoder)和编码器(Encoder)的输入建立联系。它与前面提到的普通注意力机制类似,其作用是让解码器能够利用编码器提供的上下文信息,从而更准确地生成目标序列。
自注意力机制的原理
在自注意力中,Query(查询)、Key(键)和 Value(值)都来自同一个输入序列。其核心思想是通过对输入序列进行三种独立的线性变换,分别得到 Q_x(查询向量)、K_x(键向量)和 V_x(值向量),然后将这三者输入到注意力机制中进行计算,公式表示为 Attention (Q_x, K_x, V_x)。
具体来说,注意力机制的工作原理如下:
- 计算注意力分数:通过计算查询向量(Q_x)与键向量(K_x)之间的点积,得到每个元素之间的注意力分数,这个分数反映了序列中不同元素之间的相关性。
- SoftMax 归一化:使用 SoftMax 函数对注意力分数进行归一化处理,得到注意力权重。这些权重表示每个元素在计算过程中对其他元素的关注程度,权重值介于 0 到 1 之间,且所有权重之和为 1。
- 加权求和:根据得到的注意力权重,对值向量(V_x)进行加权求和,得到最终的自注意力输出。这个输出包含了序列中其他元素的上下文信息,使得每个元素的表示都能够融合整个序列的信息。
自注意力机制在机器翻译和 Seq2Seq 任务中的应用
在机器翻译以及一般的序列到序列(Seq2Seq)任务中,自注意力机制在序列编码方面发挥着关键作用。以往的研究大多只在解码端应用注意力机制,而 Google 的创新之处在于将自注意力机制应用于序列编码阶段,具体采用的是 Self Multi-Head Attention。
- Self Multi-Head Attention 的计算公式:Y = MultiHead(X, X, X),即将同一个序列 X 同时当作查询(Query)、键(Key)以及值(Value)输入到多头注意力机制中进行运算。通过这种方式,模型能够挖掘序列内部的各种联系,例如句子中单词之间的语义关联、语法结构依存关系等。
- Multi-Head-Attention 的具体操作:首先将经过嵌入(embedding)处理后的序列 X 按照维度 d m o d e l d_{model} dmodel = 512 切割成 h = 8 个部分,然后分别对每个部分进行 self-attention 计算,最后将这些计算结果合并在一起。这种多头注意力机制允许模型在不同的子空间中学习到不同的关系,从而更全面地捕捉序列内部的复杂信息。
2.3.1自注意力的工作原理
这段内容详细解释了自注意力机制的工作原理,以句子“The cat sat on the mat.”为例,阐述了自注意力机制如何处理输入序列。以下是对其的解释:
2.3.1 自注意力的工作原理
以句子“The cat sat on the mat.”为例,说明自注意力机制是如何工作的。
(1)嵌入
模型将输入序列中的每个单词转换为高维向量表示。这个过程称为嵌入,目的是让模型能够理解单词之间的语义相似性。例如,“cat”和“dog”这两个单词在嵌入空间中可能会比较接近,因为它们在语义上有相似之处。
(2)查询、键和值向量
模型为每个单词生成三个向量:查询向量(Query)、键向量(Key)和值向量(Value)。在训练过程中,模型会学习这些向量,它们各自有不同的作用:
- 查询向量(Query):表示单词的查询,即模型在序列中寻找的内容。例如,当处理单词“cat”时,查询向量可能表示模型想要知道“cat”与其他单词的关系。
- 键向量(Key):表示单词的键,即序列中其他单词应该注意的内容。例如,当处理单词“sat”时,键向量可能表示“sat”与“cat”或“mat”等单词的关联。
- 值向量(Value):表示单词的值,即单词对输出所贡献的信息。例如,当处理单词“the”时,值向量可能表示“the”在句子中的具体含义或作用。
(3)注意力分数
模型计算每对单词之间的注意力分数。这通常是通过计算查询向量和键向量的点积来实现的,以评估单词之间的相似性。例如,计算“cat”(查询向量)和“mat”(键向量)之间的点积,得到一个分数,表示“cat”对“mat”的关注度。
(4)SoftMax 归一化
使用softmax函数对注意力分数进行归一化,得到注意力权重。这些权重表示每个单词应该关注序列中其他单词的程度。softmax函数确保所有权重的和为1,使得注意力权重在0到1之间。例如,如果“cat”对“mat”的注意力分数最高,那么“cat”对“mat”的注意力权重也会相对较高。
(5)加权求和
最后,模型使用注意力权重对值向量进行加权求和,得到每个单词的自注意力机制输出。这个输出捕获了来自其他单词的上下文信息。例如,“cat”的输出会包含来自“sat”、“on”、“the”和“mat”等单词的信息,从而使得“cat”的表示更加丰富,包含了上下文的语义信息。
2.3.2 Self-Attention优点
参数少
- 含义:自注意力机制的参数数量相对较少。具体来说,其参数量为 O ( n 2 d ) O(n^2d) O(n2d),其中 n n n表示序列的长度,d表示每个序列元素的维度。
- 对比:相比之下,循环神经网络(RNN)的参数量为 O ( n d 2 ) O(nd^2) O(nd2),卷积神经网络(CNN)的参数量为 O ( k n d 2 ) O(knd^2) O(knd2),其中 k是卷积核的大小。
- 优势:当序列长度n远小于维度d时,自注意力机制的计算速度更快,因为其参数量相对较少,计算复杂度更低。
可并行化
- 含义:自注意力机制的计算过程可以并行化处理。
- 对比:
- RNN:需要一步步递推才能捕捉到序列中的信息,即每个时间步的计算依赖于前一个时间步的结果,因此无法并行计算。
- CNN:虽然可以通过层叠来扩大感受野,但其并行化程度相对较低,且需要多层结构来捕捉长距离依赖。
- 自注意力机制:每一步的计算不依赖于上一步的计算结果,因此可以像 CNN 一样并行处理,大大提高了计算效率。
- 优势:并行化处理能够显著提升模型的训练和推理速度,尤其在处理大规模数据时,可以大幅减少计算时间。
捕捉全局信息
- 含义:自注意力机制能够有效地捕捉序列中的全局信息。
- 优势:
- 解决长时依赖问题:在传统的序列模型中,如 RNN,长距离的信息可能会被弱化,导致模型难以捕捉到序列中较远位置之间的依赖关系。而自注意力机制通过计算序列中每个元素之间的注意力权重,能够直接建立任意两个元素之间的联系,从而更好地解决长时依赖问题。
- 一步获取全局信息:自注意力机制只需一步计算,就可以获取整个序列的全局信息。它能够从序列中挑选出重要的信息,即使文本较长,也能抓住重点,避免丢失重要的信息。这使得模型在处理序列数据时,能够更全面地理解序列的语义和结构。
2.3.3 Self-Attention缺点
Self-Attention机制的计算量问题
- 三次线性映射的计算量:在Self-Attention中,首先需要对输入序列X进行三次线性映射,分别得到查询(Query)、键(Key)和值(Value)。这个过程的计算量相对较小,与卷积核大小为3的一维卷积相当,计算复杂度为O(n),其中n是序列的长度。
- 矩阵乘法的计算量:Self-Attention机制需要计算查询和键之间的点积,以及最终的加权求和。这两个操作都涉及到序列自身的矩阵乘法,每次矩阵乘法的计算量为O(n²d),其中d是特征的维度。当序列长度n较大时,这个计算量会变得非常大,导致Self-Attention机制在处理长序列时效率较低。
Self-Attention机制在捕捉位置信息方面的不足
- 问题描述:Self-Attention机制本身无法捕捉序列中的顺序关系,即无法直接学习到序列中元素的位置信息。这是因为Self-Attention是基于内容的相似性来计算注意力权重,而忽略了元素在序列中的位置。
- 解决方案:为了弥补这一不足,可以通过加入位置信息来改善。例如,在Transformer模型中,通常会使用位置编码(Position Embedding)来为每个序列元素添加位置信息。位置编码可以是一个固定的位置向量,也可以是通过学习得到的位置表示。通过这种方式,模型能够在计算注意力权重时同时考虑内容和位置信息。
Transformer模型的局限性
- 实践上的局限性:
- 特定任务表现不佳:存在一些任务,如复制字符串的任务,RNN能够轻松应对,而Transformer则未能有效解决。这可能是因为Transformer在处理这类任务时,难以捕捉到序列中的局部结构和顺序信息。
- 对序列长度的敏感性:当推理过程中遇到的序列长度超出训练时的最大长度时,Transformer的表现不如RNN。这是因为Transformer依赖于位置嵌入来处理序列,而当遇到未曾见过的位置嵌入时,模型可能会出现性能下降的情况。
- 理论上的局限性:与RNN不同,Transformer模型并不具备计算上的通用性,即非图灵完备。这意味着Transformer模型在处理NLP领域中的一些推理和决策等计算密集型问题时,存在固有的局限性,无法独立完成某些复杂的计算任务。相比之下,RNN架构具有更强的通用性,能够更好地处理这类问题。
2.3.4 Add & Norm
Add & Norm操作的解释
在深度学习中,尤其是Transformer架构中,“Add & Norm”是一种非常重要的操作,它结合了残差连接(Add)和层归一化(Normalization),用于提升模型的训练效率和稳定性。
残差连接(Add)
残差连接的核心思想是将子层的输出直接加回到子层的输入上。这种操作可以看作是一种“短路”,允许梯度在反向传播时直接流过网络层,而不需要经过复杂的非线性变换。这在深层网络中尤为重要,因为深层网络很容易出现梯度消失或梯度爆炸的问题。
数学表达:
如果输入是x,子层的输出是
SubLayer
(
x
)
\text{SubLayer}(x)
SubLayer(x),那么残差连接的输出为:
Residual
=
x
+
SubLayer
(
x
)
\text{Residual} = x + \text{SubLayer}(x)
Residual=x+SubLayer(x)
作用:
- 缓解梯度消失问题:通过直接将输入加到输出上,残差连接为梯度提供了一条“捷径”,使得梯度可以更顺畅地反向传播。
- 简化训练过程:即使网络非常深,残差连接也能让模型更容易训练。
层归一化(Norm)
层归一化是一种归一化技术,它与批量归一化(Batch Normalization)不同。层归一化是对单个样本的所有特征进行归一化,而不是对整个批次中的特征进行归一化。
作用:
- 稳定训练过程:通过将每个特征的均值变为0,标准差变为1,层归一化可以减少内部协变量偏移(Internal Covariate Shift),即由于参数更新导致的输入分布变化。
- 提高模型稳定性:归一化可以减少特征之间的尺度差异,避免某些特征在学习过程中占据主导地位,从而提高模型的泛化能力。
- 适应小批量和在线学习:层归一化不依赖于批次大小,因此在处理小批量数据或者在线学习时,依然可以保持稳定。
. Add & Norm的结合
“Add & Norm”操作将残差连接和层归一化结合在一起,具体步骤如下:
- 计算子层的输出: SubLayer ( x ) \text{SubLayer}(x) SubLayer(x)
- 执行残差连接: Residual = x + SubLayer ( x ) \text{Residual} = x + \text{SubLayer}(x) Residual=x+SubLayer(x)
- 应用层归一化: Output = LayerNorm ( Residual ) \text{Output} = \text{LayerNorm}(\text{Residual}) Output=LayerNorm(Residual)
这种组合操作在Transformer架构中非常常见,因为它既利用了残差连接的梯度优势,又利用了层归一化的稳定性优势。
结合例子的解释
假设我们有一个简单的神经网络层,输入是一个二维向量 x = [1, 2],子层的输出是 SubLayer ( x ) = [ 3 , 4 ] \text{SubLayer}(x) = [3, 4] SubLayer(x)=[3,4]。
-
- 残差连接
首先,我们将子层的输出加回到输入上:
Residual = x + SubLayer ( x ) = [ 1 , 2 ] + [ 3 , 4 ] = [ 4 , 6 ] \text{Residual} = x + \text{SubLayer}(x) = [1, 2] + [3, 4] = [4, 6] Residual=x+SubLayer(x)=[1,2]+[3,4]=[4,6]
- 残差连接
-
- 层归一化
接下来,我们对残差连接的结果进行层归一化。假设我们使用简单的均值为0、标准差为1的归一化:
- 计算均值: μ = 4 + 6 2 = 5 \mu = \frac{4 + 6}{2} = 5 μ=24+6=5
- 计算标准差: σ = ( 4 − 5 ) 2 + ( 6 − 5 ) 2 2 = 1 \sigma = \sqrt{\frac{(4-5)^2 + (6-5)^2}{2}} = 1 σ=2(4−5)2+(6−5)2=1
- 归一化:
Output = LayerNorm ( Residual ) = [ 4 − 5 1 , 6 − 5 1 ] = [ − 1 , 1 ] \text{Output} = \text{LayerNorm}(\text{Residual}) = \left[ \frac{4 - 5}{1}, \frac{6 - 5}{1} \right] = [-1, 1] Output=LayerNorm(Residual)=[14−5,16−5]=[−1,1]
- 层归一化
-
- 结果
最终,经过“Add & Norm”操作后,输出为 ( [-1, 1] )。这个输出既保留了残差连接带来的梯度优势,又通过层归一化稳定了特征尺度。
- 结果
在Transformer架构中的应用
我们有一堆数据(就是x序列),它们要经过一个复杂的过程,变得更有用。这个过程就像是一条流水线,有好几道工序。
-
第一道工序:Multi-Head Self-Attention(多头自注意力机制)
这个过程就像是让数据自己“思考”,看看哪些部分比较重要,哪些部分可以互相帮助。就好比在一个团队里,每个人都在观察其他人的工作,看看谁需要帮忙,谁做得好。 -
第二道工序:Add & Norm(加法和归一化)
这一步有点像“调整和校准”。数据经过第一步处理后,会和原来的自己做一个对比,看看有没有变得更好,然后调整一下,让它们保持在一个合理的范围内。这就好比你做完一件事后,回头看看自己原来的想法,调整一下,让自己做得更完美。 -
第三道工序:Feed-Forward Network(前馈神经网络,简称FFN)
这一步就像是给数据再“加工”一下,让它变得更强大。就好比把一个半成品再加工,让它变成一个更完整的东西。 -
第四道工序:Norm(归一化)
这一步和第二步有点像,也是调整和校准。数据经过第三步加工后,可能会变得有点“失控”,所以再调整一下,让它保持在一个合理的范围内。 -
最后一步:残余连接(Residual Connection)
这个过程很重要。在每个工序之后,数据都会和原来的自己做一个对比,看看有没有进步。如果进步了,就保留下来;如果没有进步,就回到原来的样子。这就好比你学了一个新技能,看看自己有没有变得更好,如果没有,就回到原来的状态,再试一次。
整个过程就是一个不断调整、优化的过程,让数据变得更有用。每一步都很重要,而且每一步都会检查一下自己有没有进步,这样就能保证最后的结果是好的。
2.4 前馈全连接网络(Position-wise Feed-Forward Networks)
FFN的结构
FFN层是一个顺序结构,包含以下三个部分:
- 第一个全连接层(FC):对输入数据进行线性变换,其计算公式为 x W 1 + b 1 xW_1 + b_1 xW1+b1,其中x是输入数据, W 1 W_1 W1 是该层的权重矩阵, b 1 b_1 b1 是偏置项。
- ReLU激活层:在第一个全连接层的输出上应用ReLU激活函数,即 m a x ( 0 , x W 1 + b 1 ) max(0, xW_1 + b_1) max(0,xW1+b1)。ReLU函数的作用是引入非线性,使模型能够捕捉更复杂的特征和模式。
- 第二个全连接层(FC):对ReLU激活层的输出再次进行线性变换,计算公式为 [ m a x ( 0 , x W 1 + b 1 ) ] W 2 + b 2 [max(0, xW_1 + b_1)]W_2 + b_2 [max(0,xW1+b1)]W2+b2,其中 W 2 W_2 W2 是该层的权重矩阵, b 2 b_2 b2 是偏置项。
FFN的作用
- 维度变换:FFN层的主要作用是完成输入数据到输出数据的维度变换。通过两个全连接层的线性变换,可以将输入数据从一个维度映射到另一个维度。
- 增加模型表达能力:在两个全连接层之间添加ReLU激活层,引入了非线性,使模型能够捕捉到更复杂的特征和模式,从而增强了模型的表达能力。
FFN的计算过程
以输入数据 x 为例,FFN的计算过程如下:
- 首先,输入数据 (x) 经过第一个全连接层的线性变换,得到 x W 1 + b 1 xW_1 + b_1 xW1+b1。
- 然后,对 x W 1 + b 1 xW_1 + b_1 xW1+b1 应用ReLU激活函数,得到 m a x ( 0 , x W 1 + b 1 ) max(0, xW_1 + b_1) max(0,xW1+b1)。
- 最后,将 m a x ( 0 , x W 1 + b 1 ) max(0, xW_1 + b_1) max(0,xW1+b1)作为输入,经过第二个全连接层的线性变换,得到最终的输出 [ m a x ( 0 , x W 1 + b 1 ) ] W 2 + b 2 [max(0, xW_1 + b_1)]W_2 + b_2 [max(0,xW1+b1)]W2+b2。
权重矩阵的维度
- W 1 W_1 W1 的维度是 (2048,512),表示输入数据的维度为512,经过第一个全连接层后,维度被提升到2048。
- W 2 W_2 W2 的维度是 (512,2048),表示经过ReLU激活层后的数据维度为2048,经过第二个全连接层后,维度被降低回512。
升维和降维的目的
先升维再降维的目的是扩充中间层的表示能力,从而抵抗ReLU带来的模型表达能力的下降。通过增加中间层的维度,模型能够学习到更丰富的特征表示,进而提高模型的性能。
2.5 Multi-Head Attention vs Multi-Head Self-Attention
基础架构相同
- 多头注意力机制的基本架构:它们都基于多头注意力机制,包含多个并行的注意力头(Attention Head)。每个头都有自己的线性变换矩阵,用于计算查询(Query)、键(Key)和值(Value)。
- 缩放点积注意力:在计算过程中,都涉及到缩放点积注意力(Scaled Dot-Product Attention),这是一种计算注意力分数的方法,通过计算查询和键的点积,再进行缩放和归一化,得到注意力权重,最后用这些权重对值进行加权求和,得到最终的注意力输出。
计算流程相似
- 线性变换:两者都需要先对输入进行线性变换,得到查询(Q)、键(K)和值(V)。这是通过可学习的线性变换矩阵实现的,将输入投影到不同的子空间中。
- 注意力计算:每个头独立地计算输入序列的查询、键和值表示之间的注意力分数,通过缩放点积注意力得到注意力权重,然后用这些权重对值进行加权求和。
- 拼接和线性变换:将所有头的输出进行拼接,再通过一个线性变换得到最终的输出。这个过程将来自多个头的见解结合起来,增强了模型理解序列内复杂关系的能力。
查询、键、值的来源不同
- 多头注意力(Multi-Head Attention):
- 不同输入源:查询(Query)、键(Key)和值(Value)可以来自不同的输入源。例如在解码器(Decoder)部分,查询(Query)来自解码器当前的输入,而键(Key)和值(Value)通常来自编码器(Encoder)的输出。
- 作用:这种机制使得模型能够将解码器当前的信息与编码器已经处理好的信息进行关联,从而更好地生成输出序列。比如在机器翻译任务中,解码器可以根据编码器对源语言句子的编码信息,更好地生成目标语言句子。
- 多头自注意力(Multi-Head Self-Attention):
- 同一输入序列:查询(Query)、键(Key)和值(Value)都来自同一个输入序列。这意味着模型关注的是输入序列自身不同位置之间的关系。
- 作用:它可以让模型自己发现句子中不同单词之间的相互关联。例如在句子 “The dog chased the cat” 中,单词 “dog” 与 “chased”、“chased” 与 “cat” 之间的关系可以通过多头自注意力来挖掘,从而更好地理解句子的语义和结构。
功能重点有所差异
- 多头注意力(Multi-Head Attention):
- 融合不同来源的信息:主要用于融合不同来源的信息。例如在机器翻译任务中,它将源语言句子经过编码器编码后的信息(作为 K 和 V)与解码器当前生成的部分目标语言句子(作为 Q)相结合。这样可以帮助解码器在生成目标语言句子时,更好地参考源语言句子的语义和结构,从而生成更准确的翻译。
- 多头自注意力(Multi-Head Self-Attention):
- 挖掘输入序列自身的内在结构和关系:更侧重于挖掘输入序列自身的内在结构和关系。在文本生成任务中,它可以帮助模型理解当前正在生成的文本自身的语义连贯和语法结构。例如在续写一个故事时,通过多头自注意力可以让模型把握已经生成的部分文本的主题、情节发展等内部关系,以便更好地续写。
输出信息性质不同
- 多头注意力(Multi-Head Attention):
- 融合后的特征:由于其融合了不同来源的信息,输出的结果往往包含了两个或多个不同输入序列之间相互作用后的特征。例如在跨模态任务(如将文本和图像信息相结合)中,输出会包含文本和图像相互关联后的综合特征,用于后续的分类或生成等任务。
- 多头自注意力(Multi-Head Self-Attention):
- 序列内部关系的特征表示:输出的是输入序列自身内部关系的一种特征表示。例如在对一个文本序列进行词性标注任务时,输出的特征能够反映出句子内部单词之间的语法和语义关联,用于确定每个单词的词性。
3. Cross Attention
3.1 Cross attention简述
概念
交叉注意力是一种注意力机制,它允许一个序列(称为“查询”序列)中的元素关注另一个序列(称为“键-值”序列)中的元素,从而在两个序列之间建立联系。这种机制在多模态学习和跨数据源交互的场景中非常有用,例如在机器翻译中,解码器(查询序列)可以利用编码器(键-值序列)提供的上下文信息来生成更准确的翻译。
序列维度要求
为了进行交叉注意力计算,两个序列必须具有相同的维度。这是因为注意力机制的计算涉及到查询(Q)、键(K)和值(V)向量的点积操作,这些操作要求参与计算的向量具有相同的维度。具体来说,假设查询序列的维度为 d q d_q dq键序列的维度为 d k d_k dk,值序列的维度为 d v d_v dv,那么为了进行点积操作,必须满足 d q = d k d_q = d_k dq=dk。这样,查询向量和键向量的点积才能正确计算,进而生成注意力权重。
序列的多样性
两个序列可以是不同模态的数据,例如:
- 文本序列:一系列单词或子词的嵌入表示。在自然语言处理任务中,文本序列通常通过词嵌入(如Word2Vec、GloVe或BERT嵌入)转化为向量表示。
- 声音序列:音频信号的时序特征表示。在语音处理任务中,声音序列可以通过梅尔频谱图(Mel-spectrogram)或其他特征提取方法转化为向量表示。
- 图像序列:图像的像素或特征图的嵌入表示。在计算机视觉任务中,图像序列可以通过卷积神经网络(CNN)提取特征图,然后将这些特征图转化为向量表示。
3.2 交叉注意力的操作
交叉注意力就是一种“找朋友”的过程,有三样东西:查询(Q)、键(K)和值(V)。
查询(Q):
- 这就像是一个“问问题”的人。查询序列里有多少个元素,最后就会有多少个答案。
- 比如,查询序列有5个元素,那最后就会有5个输出,每个查询元素都会找到一个对应的答案。
键(K)和值(V):
- 这两个是“回答问题”的人。
- 键(K):它的作用是“标记”,用来告诉查询(Q):“嘿,我这儿有你需要的信息,跟我匹配一下吧!”
- 值(V):才是真正的“答案”。当查询(Q)找到和键(K)最匹配的部分时,就会从值(V)里拿到真正的信息。
过程:
- 查询(Q)会去和键(K)“打招呼”,看看谁和自己最像(也就是最相关)。
- 如果查询(Q)和某个键(K)很像,那它就会从对应的值(V)里拿到信息。
- 最后,查询(Q)把拿到的信息组合起来,就变成了自己的输出。
简单来说,交叉注意力就是让查询(Q)去另一个地方(键和值)找它需要的信息,通过“打招呼”(匹配)找到最相关的部分,然后把信息带回来。
4. Cross Attention 和 Self Attention 主要的区别
Cross Attention(交叉注意力)和 Self Attention(自注意力)是注意力机制中的两种不同类型,它们在信息交互的方式上有显著的区别。具体来说,主要区别体现在输入的来源和信息交互的对象上。
4.1 Self Attention 的定义与特点
自注意力(Self Attention) 是 Transformer 架构中的核心组成部分,其主要功能是捕捉序列内部各元素之间的依赖关系。
(1) 定义与工作原理:
自注意力机制允许序列中的每个元素(如单词或特征向量)与序列中的所有其他元素进行交互。这一过程涉及为每个元素生成查询(Query)、键(Key)和值(Value)三个向量。通过计算查询向量与所有键向量之间的点积,得到注意力权重,这些权重随后用于对值向量进行加权求和,从而为每个元素生成一个包含上下文信息的向量。
(2) 作用:
- 并行处理:Self Attention 允许模型并行处理序列中的所有元素,因为它不依赖于序列元素之间的顺序。
- 长距离依赖:它能够有效地捕捉长距离依赖,即使序列中的元素相隔很远,也能通过注意力机制建立联系。
- 参数效率:由于所有元素共享相同的权重,Self Attention 在参数数量上比传统的循环神经网络更为高效。
(3) 主要特点:
- 信息来源:自注意力的查询、键和值均来源于同一输入序列,这使得它能够从序列内部挖掘信息。
- 应用场景:自注意力适用于需要对序列内部元素之间的相互依赖进行建模的场景,如 NLP 中的句子编码、图像处理中的自相关特征提取等。
4.2 Cross Attention 的定义与特点
交叉注意力(Cross Attention) 是一种机制,它促进了两个不同序列之间的信息交互,广泛应用于多模态任务以及需要跨数据源交互的场景,例如在序列到序列模型(Seq2Seq)中,或者在图像与文本信息对齐的任务中。
(1) 定义与工作原理:
交叉注意力机制涉及两个序列的交互,其中一个序列提供查询(Query)向量,而另一个序列提供键(Key)和值(Value)向量。这种配置允许模型在两个序列之间建立联系,通过匹配和关联不同序列中的信息,从而在它们之间建立相关性。
(2) 作用:
- 跨模态学习:Cross Attention 在多模态学习中尤为重要,因为它能够将不同模态(如文本、图像、声音)的信息进行有效融合。
- 交互式解码:在序列到序列的任务中,Cross Attention 使得解码器能够利用编码器提供的上下文信息,从而更准确地生成目标序列。
- 灵活性:Cross Attention 提供了灵活性,因为它允许模型动态地关注另一个序列中与当前任务最相关的部分。
4.3 Self Attention 和 Cross Attention 的对比
特性 | Self Attention | Cross Attention |
---|---|---|
输入来源 | 在 Self Attention 中,Query、Key 和 Value 均来源于同一个序列。这意味着模型是在内部进行信息的自我比较和关联。 | 在 Cross Attention 中,Query 来自于一个序列,而 Key 和 Value 则来自于另一个不同的序列。这种配置允许模型在不同的数据源之间建立联系。 |
信息交互对象 | Self Attention 使得序列中的每个元素都能够关注序列中的所有其他元素,并基于这种关注来更新自己的表示。 | Cross Attention 则允许来自一个序列的元素(通过 Query)关注另一个序列中的所有元素(通过 Key 和 Value),从而实现跨序列的信息融合。 |
应用场景 | Self Attention 广泛应用于需要理解序列内部复杂依赖关系的场景,例如在自然语言处理中,用于捕捉句子中单词之间的相互作用。 | Cross Attention 适用于那些需要在不同序列之间建立联系的场合,如机器翻译中的编码器和解码器之间的交互,或者在多模态学习中,将文本信息与图像特征对齐。 |
特征捕捉 | Self Attention 能够捕捉并编码序列内部的全局依赖关系,使得每个位置的表示都融入了序列中其他位置的信息。 | Cross Attention 则专注于捕捉并编码不同序列之间的全局依赖关系,使得一个序列的表示能够反映另一个序列中的相关信息。 |
CrossAttention代码实现
import torch # 导入torch库
import torch.nn as nn # 导入torch.nn库
import torch.nn.functional as F # 导入torch.nn.functional库
# 定义CrossAttention类
class CrossAttention(nn.Module):
# 初始化函数
# dim: 模型的维度
# num_heads: 多头注意力的头数
# qkv_bias: qkv层是否有偏置
# qk_scale: qk层缩放因子
# attn_drop: 注意力矩阵的丢弃率
# proj_drop: 最后输出的丢弃率
# window_size: 窗口大小
# attn_mask: 注意力矩阵的掩码
# attn_head_dim: 注意力矩阵的头维度
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0,
proj_drop=0.0, window_size=None, attn_head_dim=None):
# 调用父类的初始化函数
super().__init__()
self.num_heads = num_heads # 多头注意力的头数
head_dim = dim // num_heads # 注意力矩阵的头维度
if attn_head_dim is not None: # 如果attn_head_dim不为None,则使用attn_head_dim作为注意力矩阵的头维度
head_dim = attn_head_dim
all_head_dim = head_dim * num_heads # 所有头的维度之和
self.scale = qk_scale or head_dim ** -0.5 # qk层缩放因子
# 定义qkv层
self.q = nn.Linear(dim, all_head_dim, bias=False) # q层
self.k = nn.Linear(dim, all_head_dim, bias=False) # k层
self.v = nn.Linear(dim, all_head_dim, bias=False) # v层
if qkv_bias: # 如果qkv层有偏置
self.q_bais = nn.Parameter(torch.zeros(all_head_dim)) # q层偏置
self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) # k层偏置
else:
self.q_bais = None # q层偏置
self.k_bias = None # k层偏置
self.v_bias = None # v层偏置
self.attn_drop = nn.Dropout(attn_drop) # 注意力矩阵的丢弃率
self.proj = nn.Linear(all_head_dim, dim) # 最后输出层
self.proj_drop = nn.Dropout(proj_drop) # 最后输出的丢弃率
# 前向传播函数
# x: 输入特征
# bool_masked_pos: 掩码位置的布尔值
# k: k层的权重
# v: v层的权重
def forward(self, x, bool_masked_pos=None, k = None, v = None):
B, N, C = x.shape # 输入特征的形状
N_k = k.shape[1] # k层的特征数
N_v = v.shape[1] # v层的特征数
q_bias, k_bias, v_bias = None, None, None # q层偏置,k层偏置,v层偏置
if self.q_bais is not None: # 如果q层有偏置
q_bias = self.q_bais # q层偏置
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) # k层偏置为0
v_bias = self.v_bias # v层偏置
q = F.linear(input=x, weight=self.q.weight, bias=q_bias) # q层
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # 实现多头注意力
k = F.linear(input=k, weight=self.k.weight, bias=k_bias) # k层
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # 实现多头注意力
v = F.linear(input=v, weight=self.v.weight, bias=v_bias) # v层
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # 实现多头注意力
q = q * self.scale # q层缩放
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) 注意力矩阵
attn = attn.softmax(dim=-1) # 注意力矩阵的softmax
attn = self.attn_drop(attn) # 注意力矩阵的丢弃率
x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # 实现多头注意力
x = self.proj(x) # 最后输出层
x = self.proj_drop(x) # 最后输出的丢弃率
return x # 返回输出
# 设置相关的维度参数和输入张量示例
batch_size = 2 # 批次大小
dim = 64 # 特征维度
num_heads = 4 # 头的数量
seq_len_query = 10 # 查询序列长度
seq_len_key_value = 8 # 键值对序列长度
# 随机生成输入张量,模拟查询、键、值
query = torch.rand(batch_size, seq_len_query, dim) # 查询序列
key = torch.rand(batch_size, seq_len_key_value, dim) # 键序列
value = torch.rand(batch_size, seq_len_key_value, dim) # 值序列
# 实例化CrossAttention模块
cross_attention_module = CrossAttention(dim=dim, num_heads=num_heads)
# 进行前向传播计算
output = cross_attention_module(query, k=key, v=value)
print("输出结果的形状:", output.shape) # 输出结果的形状: torch.Size([2, 10, 64])