Attention原理详解


说到Attention可能大家想到的是最经典的Attention模型,实际上Attention有很多变种,Attention也有很多种计算方法,包括近来大火的self-attention也是Attention的一种。

下面我会由浅到深的介绍各种Attention机制,希望通过这篇博文能够让大家认识一个别样的Attention。

Attention是跟seq2seq模型一起用的,如果大家对seq2seq模型不了解的可以点击此链接了解其中的原理。

Attention模型

注意力机制可以分为两种:

一种叫做全局注意力,就是计算注意力的时候使用Encoder所有的隐藏状态。
还有一种叫局部注意力,使用Encoder一部分隐藏状态(目的是为了减少计算量,通常计算一个固定窗口内的部分)。

首先介绍一下什么是对齐,下面的内容会出现。

对齐

对齐的意思是把原始文本和对应的翻译进行匹配(就是哪个词对应着翻译成哪个词)。如下图所示:

在这里插入图片描述

模型介绍

看过上一篇博文的应该知道不加Attention的seq2seq模型是这样的:
在这里插入图片描述
Decoder中每一步输入的语义编码c都是相同的,也就意味着句子X中的每个单词对输出Y中的每一个单词的影响都是相同的,但实际上每个输出跟每个输入的关联度并不相同。

这样就会产生两个弊端:

一是语义向量无法完全表示整个序列的信息。
二是先输入的内容携带的信息会被后输入的信息稀释掉,或者说,被覆盖了,输入序列越长,这个现象就越严重。

这就使得在解码的时候一开始就没有获得输入序列足够的信息, 那么解码的准确度自然也就要打折扣。为了解决上面的弊端,就需要用到我们的Attention Model(注意力模型)来解决该问题。在机器翻译的时候,让生成词不是只能关注全局的语义编码向量c,而是增加了一个“注意力范围”。

通过这种方法,模型能够有选择地关注输入序列的有用部分,从而了解它们之间的对齐关系,有助于模型更好地处理输入较长的句子。

模型结构如下:
在这里插入图片描述
此模型中语义编码c1、c2、c3各不相同,y1、y2、y3的生成方法为:

在这里插入图片描述
比如输入的是英文句子:Tom chase Jerry,生成:“汤姆”,“追逐”,“杰瑞”。

注意力分配概率分布值的通用计算过程为:

在这里插入图片描述
当前输出词Yi针对某一个输入词j的注意力权重由当前的隐层Hi,以及输入词j的隐层状态(hj)共同决定;然后再接一个softmax得到0-1的概率值。即通过函数F(hj,Hi)来获得目标单词Yi和每个输入单词对应的对齐可能性。

那么整个翻译过程表示为:

在这里插入图片描述
每一个c会自动去选取与当前所要输出的y最合适的上下文信息。

具体来说,我们用 aij 衡量编码中第j阶段的hj和解码时第i阶段的相关性,最终Decoder中第i阶段的输入的上下文信息(语义编码) ci 就来自于所有 hj 对 aij 的加权和。

如下图所示,红色代表输入输出相关性很强,相应的权重会很大,比如“ I ”和“我”的相关性就很大,“China”和“中”、“国”两个输入的相关性很大。

在这里插入图片描述
输入的序列是“我爱中国”,因此,Encoder中的h1、h2、h3、h4就可以分别看做是“我”、“爱”、“中”、“国”所代表的信息。

在翻译成英语时,第一个上下文c1应该和“我”这个字最相关,因此对应的 a11就比较大,而相应的 a12 、a13 、 a14 就比较小。c2应该和“爱”最相关,因此对应的 a22 就比较大。最后的c3和h3、h4最相关,因此 a33 、 a34的值就比较大。

具体模型权重 aij 是如何计算出来的呢?

Attention整体流程

下面的内容会详细介绍Attention的整个流程:

通过为每个单词分配一个权重,注意力机制能够保证当前翻译的单词对原文各个单词的关注点不同(就是对照着原文翻译)。由于这个权重可能大于1,为了方便我们使用softmax进行归一化,得到归一化权重,然后计算Encoder隐藏状态和其对应归一化权重的加权和,得上下文向量c(语义编码)。

注意力层的实现可以分为6个步骤。

Step1 计算Encoder的隐藏状态和Decoder的隐藏状态

首先计算第一个解码器隐藏状态(红色)和所有可用的编码器隐藏状态(绿色)。下图中有4个编码器隐藏状态和当前解码器的隐藏状态。要想输出Decoder的第一个隐藏的状态,需要给Decoder一个初始状态和一个输入,例如采用Encoder的最后一个状态作为Decoder的初始状态,输入为0。

在这里插入图片描述

Step2 获取每个编码器隐藏状态对应的分数

计算Decoder的第一个隐藏状态和Encoder所有的隐藏状态的相关性,这里采用点积的方式(默认两个向量长度一样)。后面是一个计算示例。
在这里插入图片描述

 decoder_hidden = [10, 5, 10]
 encoder_hidden          score
 ---------------------
      [0, 1, 1]           15 (= 10×0 + 5×1 + 10×1, the dot product)
      [5, 0, 1]           60
      [1, 1, 0]           15
      [0, 5, 1]           35

在上面的例子中,对于编码器的隐藏状态[5, 0, 1],我们获得了较高的注意分值60。这意味着下一个要翻译的单词与此编码器的隐藏状态相关性很大。

Step3 通过softmax归一化分数

我们把得到的分数输入到softmax层进行归一化,归一化之后的分数(标量)加起来等于1,归一化后的分数代表注意力分配的权重 。
在这里插入图片描述

 encoder_hidden   score    score^
 -----------------------------
      [0, 1, 1]     15       0
      [5, 0, 1]     60       1
      [1, 1, 0]     15       0
      [0, 5, 1]     35       0

注意,经过softmax之后的分数score^,注意力的分配仅按预期放在了[5, 0, 1]上。实际上,这些数字不是二进制的,而是0到1之间的一个浮点数。

Step4 用每个编码器的隐藏状态乘以其softmax得分

通过将每个编码器的隐藏状态与其softmax之后的分数(标量)相乘,我们得到对齐向量 或标注向量。这正是对齐产生的机制。
在这里插入图片描述

 encoder_hidden   score  score^  alignment
 -----------------------------------------
 [0, 1, 1]          15     0     [0, 0, 0]
 [5, 0, 1]          60     1     [5, 0, 1]
 [1, 1, 0]          15     0     [0, 0, 0]
 [0, 5, 1]          35     0     [0, 0, 0]

在这里,我们看到除了[5, 0, 1]外,所有编码器隐藏状态的对齐都被降低到0,这是因为注意力得分较低。这意味着我们可以期望第一个被翻译的单词应该与输入单词中使用[5, 0, 1]嵌入所表示的单词匹配。

Step5 把所有对齐的向量加起来

对齐向量进行求和,生成上下文向量(语义编码)。上下文向量是前一步对齐向量的聚合信息。
在这里插入图片描述

 encoder_hidden   score   score^   alignment
 -------------------------------------------
   [0, 1, 1]       15       0      [0, 0, 0]
   [5, 0, 1]       60       1      [5, 0, 1]
   [1, 1, 0]       15       0      [0, 0, 0]
   [0, 5, 1]       35       0      [0, 0, 0]
 context = [0+5+0+0, 0+0+0+0, 0+1+0+0] = [5, 0, 1]

Step6 将上下文向量输入到Decoder中

输入的方式和模型有关。我们会在后面的例子中看到不同的模型是如何利用Encoder的上下文向量的。
在这里插入图片描述

完整过程

在这里插入图片描述

Attention到底是怎么工作的?

反向传播,意不意外?

反向传播将尽一切努力确保输出接近基本事实。这是通过改变RNNs和score函数(如果有的话)中的权重来实现的。这些权重将影响编码器的隐藏状态和解码器的隐藏状态,从而影响注意力得分。

Attention实例

通过上面的内容我们已经了解了seq2seq和seq2seq +注意模型,接下来会介绍3个基于seq2seq的NMT架构,来看一下具体是怎么用的。为了比较效果,文章给出了每个模型的双语评估研究(BLEU)分数 --用于评估生成的句子到参考句子的标准度量。

1. Neural Machine Translation by Learning to Jointly Align and Translate

这是注意力机制用于机器翻译的开篇之作。作者在《Neural Machine Translation by Jointly Learning to Align and Translate》的标题中使用了align这个词,意思是在训练模型的同时直接调整负责得分的权重。以下是关于这个架构需要注意的几点:

  1. 编码器是一个双向(正向+反向)门控循环单元(BiGRU)。解码器是一个GRU,其初始隐藏状态是从反向编码器GRU的最后隐藏状态修改而来的向量(未在下图中示出)。
  2. 注意层中的score函数是additive/concat。
  3. 下一个解码器时间步的输入是前一个解码器时间步(粉红色)的输出与当前时间步(深绿色)的上下文向量之间的拼接。
    在这里插入图片描述
    作者在WMT’14英语 - 法语数据集上获得了26.75的BLEU分数。

2. Effective Approaches to Attention-based Neural Machine Translation

论文链接:

Effective Approaches to Attention-based Neural Machine Translation

这篇论文的关键点如下:

  1. Encoder是两层的LSTM网络。 Decoder也一样,其初始隐藏状态是最后一个编码器的隐藏状态。
  2. 他们实验的score函数是(i) additive/concat, (ii) 点积,(iii) location-based,和(iv) “general”。
  3. 来自当前解码器时间步长的输出与来自当前时间步长的上下文向量之间的串联被输入到前馈神经网络,以给出当前解码器时间步长的最终输出(粉红色)。
    在这里插入图片描述
    在WMT’15英德测试中,该模型的BLEU得分为25.9。

3. Google的神经机器翻译(GNMT)

论文链接:

Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation

GNMT是前面两个例子的组合(深受第一个例子的启发)。

这篇论文的关键点如下:

  1. 编码器由8个LSTM组成,其中第一个LSTM是双向的(其输出是拼接起来的),来自连续层的输出之间存在残差连接(从第3层开始)。解码器是8个单向LSTM的独立堆叠。
  2. 使用的score函数是additive/concat,类似于第一个例子。
  3. 同样,就像第一个例子中一样,下一个解码器时间步的输入是前一个解码器时间步(粉红色)的输出和当前时间步(深绿色)的上下文向量之间的连接。

在这里插入图片描述
残差连接用曲线箭头表示。注意,LSTM单元格只显示隐藏状态和输入,它不显示单元格状态输入。

该模型在WMT’14英法测试上达到38.95 BLEU,在WMT’14英德测试上达到24.17 BLEU。

Score函数

下面是一些由Lilian Weng编辑的评分函数。

在这篇文章中提到了Additive/concat和点积。

score函数涉及点积运算(点积、余弦相似度等),其思想是度量两个向量之间的相似度。对于前馈神经网络评分函数,其思想是让模型在变换的同时学习对齐权值。
Summary of score functions
在这里插入图片描述
h代表encoder的隐藏状态,s代表decoder的隐藏状态。

Transformer

transformer源于论文《Attention Is All You Need》

它是seq2seq模型,但全程没有用到rnn或者cnn,它的Encoder部分和Decoder部分都使用了self-attention,self-attention可以说是Attention模型中的一大进步,为NLP工作者带来了极大的方便。

有关transformer的介绍在我的另一篇博客里,感兴趣的可以点此链接阅读。

<think>嗯,用户让我介绍一下Flash Attention,需要详细解释。首先,我得回忆一下自己对这个概念的理解。Flash Attention是一种用于加速Transformer模型中注意力机制计算的方法,对吧?那传统的注意力机制存在什么问题呢?可能是计算复杂度高,尤其是当序列长度增加时,内存和计算资源消耗大,导致训练和推理变慢。Flash Attention通过优化内存访问和计算顺序来解决这个问题,对吗? 接下来,我需要详细梳理Flash Attention原理。记得它主要用了两种技术:分块计算(Tiling)和重计算(Recomputation)。分块计算是将大的注意力矩阵分成小块,逐个处理,减少GPU显存的使用。而重计算则是在反向传播时重新计算某些中间结果,避免存储它们,从而节省内存。不过,用户可能需要更具体的解释,比如这些技术是如何具体实现的,为什么能提高效率。 另外,Flash Attention的核心优势是什么?高效的内存利用、更快的计算速度,还有支持更长的序列长度。这些都是用户可能关心的点。需要举一些例子,比如在训练大型模型时,显存占用减少,可以处理更长的文本序列,这对自然语言处理任务尤其重要。 可能还需要对比传统注意力机制的计算方式。传统方法需要计算整个注意力矩阵,存储中间结果,导致显存需求与序列长度的平方成正比。而Flash Attention通过分块和重计算,将显存需求降低到线性关系,这样在处理长序列时优势明显。 用户可能对具体实现细节感兴趣,比如分块的大小如何选择,重计算是如何在反向传播中应用的。这里可能需要解释一下算法流程,比如前向传播时如何分块计算注意力,反向传播时如何利用存储的某些中间变量和重计算梯度,以减少内存占用。 还要注意提到Flash Attention的变体,比如FlashAttention-2,在哪些方面做了改进,比如并行化和优化计算步骤,进一步提升速度和效率。可能用户会想知道不同版本之间的区别,以及实际应用中的选择建议。 另外,应用场景也是一个重点。比如在训练大规模Transformer模型(如GPT、BERT)时,使用Flash Attention可以显著减少显存占用,加快训练速度,或者在处理长文本、高分辨率图像时,能够处理更长的序列。这对于实际应用中的资源节省和效率提升非常重要。 可能还需要提到相关的论文和作者,比如Tri Dao等人提出的Flash Attention,增加权威性。同时,指出目前主流的深度学习框架是否已经集成了Flash Attention,比如PyTorch是否有对应的接口或库,方便用户直接调用。 最后,总结一下Flash Attention的重要性,强调它在当前大模型训练中的关键作用,以及未来可能的改进方向。确保用户理解其价值和实际应用中的好处。</think>Flash Attention 是一种**高效计算Transformer模型中注意力机制**的算法,由斯坦福大学团队于2022年提出。它通过优化内存访问和计算顺序,显著减少了传统注意力机制的计算开销和显存占用,成为训练大模型(如LLaMA、GPT-4)的重要技术。 --- ### **传统注意力机制的问题** Transformer的核心是自注意力(Self-Attention),其计算复杂度为 \(O(N^2)\)(N是序列长度),具体问题包括: 1. **显存占用高**:需存储中间矩阵(QK^T、Softmax结果等),显存需求与\(N^2\)成正比。 2. **计算效率低**:频繁读写显存(如GPU的HBM)导致延迟,尤其是长序列场景。 --- ### **Flash Attention的核心思想** 通过**分块计算(Tiling)**和**重计算(Recomputation)**,减少显存访问次数,实现近似计算但结果严格一致。 #### 1. **分块计算(Tiling)** - 将大的Q、K、V矩阵分割成小块(Tile),在SRAM(高速缓存)中逐块计算局部注意力,避免一次性存储整个QK^T矩阵。 - 通过动态更新最终输出(O)和Softmax归一化的统计量(如最大值和求和项),避免中间结果存储。 #### 2. **重计算(Recomputation)** - 反向传播时,不存储前向的中间矩阵(如QK^T),而是根据保留的Q、K、V块重新计算梯度。 - 牺牲少量计算时间换取显存大幅降低。 --- ### **算法优势** 1. **显存优化**:显存占用从 \(O(N^2)\) 降为 \(O(N)\)。 2. **速度提升**:利用GPU显存层次结构(HBM → SRAM),减少I/O操作,计算速度提升2-4倍。 3. **支持更长序列**:可处理数万长度的序列(如DNA分析、长文本)。 --- ### **具体实现步骤(简化版)** 1. **前向传播**: - 将Q、K、V分块,每次加载一块到SRAM。 - 计算局部QK^T,应用Softmax并加权得到局部注意力输出。 - 根据局部Softmax统计量(最大值、求和值),迭代更新全局输出和归一化因子。 2. **反向传播**: - 利用前向保留的分块Q、K、V和输出梯度,重新计算局部QK^T和Softmax,避免存储中间矩阵。 --- ### **FlashAttention-2 的改进** 2023年提出的FlashAttention-2进一步优化: 1. **减少非矩阵运算**:优化Reshape、转置等操作。 2. **并行化**:在序列长度和批次维度上增加并行度。 3. **块大小调整**:根据GPU特性(如A100的SRAM大小)调整分块策略。 --- ### **应用场景** 1. **大模型训练**:如GPT、BERT、LLaMA等,显存减少可支持更大批次或更长序列。 2. **长序列任务**:文本摘要、高分辨率图像生成(如ViT)、基因序列分析。 3. **边缘设备**:降低显存需求后,可在资源有限的设备上部署模型。 --- ### **代码示例(伪代码)** ```python def flash_attention(Q, K, V, block_size=64): O = zeros_like(V) for i in range(0, N, block_size): Qi = Q[i:i+block_size] for j in range(0, N, block_size): Kj, Vj = K[j:j+block_size], V[j:j+block_size] Sij = Qi @ Kj.T # 分块计算QK^T Pij = softmax(Sij) # 局部Softmax Oi += Pij @ Vj # 累加到输出块 return O ``` --- ### **总结** Flash Attention通过**算法与硬件的协同设计**,在数学等价的前提下,将注意力计算复杂度从显存瓶颈的\(O(N^2)\)降至\(O(N)\),成为现代大模型的基石技术。其后续变种(如FlashDecoding)还优化了推理阶段的性能。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值