【总结】Attention详解

本文深入探讨了Seq2Seq模型及其局限性,引出了Attention机制的重要性。Attention允许模型根据任务目标关注输入序列的不同部分,解决了信息压缩的问题。Self-Attention简化了这一过程,Multi-head Attention则进一步增强了模型的能力。尽管Attention模型在处理序列信息时表现出色,但顺序信息的丢失是其一大挑战,Position Embedding被用来引入位置信息。总结来说,Attention机制在自然语言处理中起到了关键作用,尤其在机器翻译、文本摘要等领域。

【总结】Attention详解


参考资料

1)

2)

3) 雷锋网

4)李宏毅老师

论文原文:Attention is all you need

1)Seq2Seq问题

序列标注的问题中,要求输入和输出序列等长,即N vs N 结构;但是在多数应用中,比如机器翻译,其输入和输出不等长,为 N VS M结构,无法用序列标注来做,这个时候要用到称之为Seq2Seq模型,或者叫Encoder-Decoder模型;这里先介绍N vs M 结构怎么操作;

1)首先,我们有输入序列, [ x 1 , x 2 , . . . , x N ] [x_1, x_2, ..., x_N] [x1,x2,...,xN] 长度为N的时间序列;Encoder-Decoder框架先将其编码为一个融合上下文语义的向量 c c c ; 这部分操作使用一个循环网络就可以实现,最简单的方法就是把RNN的最后一个隐状态赋值给c,还可以对最后的隐状态做一个变换得到c,也可以对所有的隐状态做变换。 这个RNN叫做Encoder-Decoder模型的编码器encoder;

在这里插入图片描述

2)拿到向量 c c c 之后,就用另一个RNN网络对其进行解码,这部分任务类似于使用图像生成文本,这部分RNN网络成为解码器Decoder; 具体做法有两种,1)把 c c c 当做RNN的初始状态输入,2)把 c c c当做每一步的输入;

在这里插入图片描述

在这里插入图片描述

由于这种Encoder-Decoder结构不限制输入和输出的序列长度,因此应用的范围非常广泛,比如:

  • 机器翻译。Encoder-Decoder的最经典应用,事实上这一结构就是在机器翻译领域最先提出的
  • 文本摘要。输入是一段文本序列,输出是这段文本序列的摘要序列。
  • 阅读理解。将输入的文章和问题分别编码,再对其进行解码得到问题的答案。
  • 语音识别。输入是语音信号序列,输出是文字序列。

2)Attention机制(思想)

Attention和Seq2Seq并没有直接联系,只是现在的Seq2Seq模型都是会加上Attention机制,所以介绍Attention之前先介绍一下Seq2Seq问题;通过第一章的介绍,我们知道了输入序列需要编码成包含全部序列信息的语义向量 c c c; 分析一下你会发现,这里就存在一个向量压缩的过程,序列的长度就成了限制模型性能的瓶颈 ;语义向量 c c c 无法对序列全部的时间步进行表示,想要在有限的向量容量里面表示更多的全文信息,我们的模型需要学会取舍:丢弃一些对语义不重要的信息,强化并保留关键的信息;

Attention的思想如同它的名字一样,就是“注意力”,在预测结果时把注意力放在不同的特征上。 举个例子: 比如我在预测一句话:“我妈今天做的这顿饭真好吃”的情感时,如果只预测正向还是负向,那真正影响结果的只有“真好吃”这三个字,前面说的“我妈今天做的这顿饭”基本没什么用,如果是直接对token embedding进行平均去求句子表示会引入不少噪声。所以引入attention机制,让我们可以根据任务目标赋予输入token不同的权重,理想情况下前半句的权重都在0.0及,后三个字则是“0.3, 0.3, 0.3”,在计算句子表示时就变成了:

最终表示 = 0.01x我+0.01x妈+0.01x今+0.01x天+0.01x做+0.01x的+0.01x这+0.01x顿+0.02x饭+0.3x真+0.3x好+0.3x吃

介绍到这里,Attention的核心思想就出来了,那就是加权求和

3)Attention在Seq2Seq中的具体操作

回到Seq2Seq的问题上来,在解码部分,我们需要对向量 c c c进行解码,所有时间步直接使用一样的向量 c c c,无法满足每一个时间步的解码要求; 更好的做法是通过在每个时间输入不同的c来解决这个问题;

在这里插入图片描述

每一个 c c c会自动去选取与当前所要输出的y最合适的上下文信息。具体来说: 对解码器的每一个时间步,我们都需要对输入序列进行一次Attention计算,得到不同的加权求和结果; 以机器翻译为例(将中文翻译成英文):

完全图解RNN、RNN变体、Seq2Seq、Attention机制

输入的序列是“我爱中国”,因此,Encoder中的h1、h2、h3、h4就可以分别看做是“我”、“爱”、“中”、“国”所代表的信息。在翻译成英语时,第一个上下文c1应该和“我”这个字最相关,因此对应的 a 11 a_{11} a11就比较大,而相应的 a 12 , a 13 , a 14 a_{12}, a_{13}, a_{14} a12,a13,a14就比较小。c2应该和“爱”最相关,因此对应的 a 22 a_{22} a22就比较大。最后的c3和h3、h4最相关,因此 a 33 , a 34 a_{33}, a_{34} a33,a34的值就比较大。

🍀 那么,还剩最后一个问题,这些权重怎么计算得到的?

事实上, a i j a_{ij} aij同样是从模型中学出的,它实际和Decoder的第 i − 1 i-1 i1阶段的隐状态、Encoder第 j j j 个阶段的隐状态有关。

同样还是拿上面的机器翻译举例, a1j 的计算(此时箭头就表示 h ′ h' h h j h_j hj 同时做变换):

完全图解RNN、RNN变体、Seq2Seq、Attention机制

a2j 的计算:

完全图解RNN、RNN变体、Seq2Seq、Attention机制

把上面的问题抽象一下; 通常我们会将输入分为query(Q), key(K), value(V)三种: Decoder(解码)的输出为Q,Encoder(编码)的输出为K和V;

权重由Q和K计算得到:
a = s o f t m a x ( f ( Q , K ) ) a = softmax(f(Q, K )) a=softmax(f(Q,K))

这里QK的具体运算f有多种方法,常见的有如下两种:

f ( Q , K ) = t a n h ( W 1 Q + W 2 K ) f(Q, K ) = tanh(W_1Q+W_2K) f(Q,K)=tanh(W1Q+W2K)

f ( Q , K ) = Q K T f(Q,K) = QK^T f(Q,K)=QKT

用如下例子说明:

img

其中 h i h^i hi是编码器每个时间步的输出, z j z^j zj是解码器每个step的输出,计算步骤是这样的:

  1. 先对输入进行Encode,得到 [ h 1 , h 2 , h 3 , h 4 ] [h^1, h^2,h^3,h^4] [h1,h2,h3,h4] ;
  2. 开始Decode,先用固定的start token也就是 $z^0 作 为 ∗ ∗ Q ∗ ∗ , 去 和 每 个 作为**Q**,去和每个 Qh^i$ (同时作为K和V)去计算attention的权重,得到加权的$c^0 $
  3. 用 $c^0 $作为解码的RNN输入(同时还有上一步的 $z^0 $),得到 $z^1 $ 并预测出第一个词是machine
  4. 再继续预测的话,就是用 z 1 z^ 1 z1 作为Q去求attention:

img

思想就是根据当前解码“状态”判断输入序列的权重分布。

4)self-attention

关于attention有很多应用,在非seq2seq任务中,比如文本分类,或者其他分类问题,会通过self attention来使用attention。这个方法思想很简单,而且计算成本相对来说不高,强烈推荐。具体来说就是: Q=K=V;

1、通过序列本身来计算序列当中每一个词的权重;

2、然后通过softmax归一化,与序列相乘,得到加权之后的向量;

这种方法中,句子中的每个词都能与句子中任意距离的其他词建立一个敏感的关系,可以说在一定程度上提升了之前所说的CNN和RNN对于长距离语义依赖建模能力的上限。 加上其操作简单,推荐使用;

5)Multi-head Attention

这个是Google提出的新概念,是Attention机制的完善。不过从形式上看,它其实就再简单不过了,就是把Q,K,V通过参数矩阵映射一下,然后再做Attention,把这个过程重复做h次,结果拼接起来就行了 ;

Multi-Head Attention

6)总结

优点:

  • 在输出序列与输入序列“顺序”不同的情况下表现较好,如翻译、阅读理解
  • 相比RNN可以编码更长的序列信息

缺点:

  • 对序列顺序不敏感: 这样的模型并不能捕捉序列的顺序!换句话说,如果将K,V 按行打乱顺序(相当于句子中的词序打乱),那么Attention的结果还是一样的。
  • 通常和RNN结合使用,不能并行化

7)分析

RNN编码器编码的语义向量本身是有序列的顺序信息的,加入attention之后弱化了该部分功能,在解码端,每一个时间步都能获取有利于解码输出的最佳权重比例,那么即便输入序列打乱顺序,对结果也没有影响;这个问题就很严重了;我丢失了重要的顺序信息; 学习不到顺序信息,那么效果将会大打折扣(比如机器翻译中,有可能只把每个词都翻译出来了,但是不能组织成合理的句子)。

Google给出的解决办法是:Position Embedding: 也就是“位置向量”,将每个位置编号,然后每个编号对应一个向量,通过结合位置向量和词向量,就给每个词都引入了一定的位置信息,这样Attention就可以分辨出不同位置的词了。

<think>嗯,用户让我介绍一下Flash Attention,需要详细解释。首先,我得回忆一下自己对这个概念的理解。Flash Attention是一种用于加速Transformer模型中注意力机制计算的方法,对吧?那传统的注意力机制存在什么问题呢?可能是计算复杂度高,尤其是当序列长度增加时,内存和计算资源消耗大,导致训练和推理变慢。Flash Attention通过优化内存访问和计算顺序来解决这个问题,对吗? 接下来,我需要详细梳理Flash Attention的原理。记得它主要用了两种技术:分块计算(Tiling)和重计算(Recomputation)。分块计算是将大的注意力矩阵分成小块,逐个处理,减少GPU显存的使用。而重计算则是在反向传播时重新计算某些中间结果,避免存储它们,从而节省内存。不过,用户可能需要更具体的解释,比如这些技术是如何具体实现的,为什么能提高效率。 另外,Flash Attention的核心优势是什么?高效的内存利用、更快的计算速度,还有支持更长的序列长度。这些都是用户可能关心的点。需要举一些例子,比如在训练大型模型时,显存占用减少,可以处理更长的文本序列,这对自然语言处理任务尤其重要。 可能还需要对比传统注意力机制的计算方式。传统方法需要计算整个注意力矩阵,存储中间结果,导致显存需求与序列长度的平方成正比。而Flash Attention通过分块和重计算,将显存需求降低到线性关系,这样在处理长序列时优势明显。 用户可能对具体实现细节感兴趣,比如分块的大小如何选择,重计算是如何在反向传播中应用的。这里可能需要解释一下算法流程,比如前向传播时如何分块计算注意力,反向传播时如何利用存储的某些中间变量和重计算梯度,以减少内存占用。 还要注意提到Flash Attention的变体,比如FlashAttention-2,在哪些方面做了改进,比如并行化和优化计算步骤,进一步提升速度和效率。可能用户会想知道不同版本之间的区别,以及实际应用中的选择建议。 另外,应用场景也是一个重点。比如在训练大规模Transformer模型(如GPT、BERT)时,使用Flash Attention可以显著减少显存占用,加快训练速度,或者在处理长文本、高分辨率图像时,能够处理更长的序列。这对于实际应用中的资源节省和效率提升非常重要。 可能还需要提到相关的论文和作者,比如Tri Dao等人提出的Flash Attention,增加权威性。同时,指出目前主流的深度学习框架是否已经集成了Flash Attention,比如PyTorch是否有对应的接口或库,方便用户直接调用。 最后,总结一下Flash Attention的重要性,强调它在当前大模型训练中的关键作用,以及未来可能的改进方向。确保用户理解其价值和实际应用中的好处。</think>Flash Attention 是一种**高效计算Transformer模型中注意力机制**的算法,由斯坦福大学团队于2022年提出。它通过优化内存访问和计算顺序,显著减少了传统注意力机制的计算开销和显存占用,成为训练大模型(如LLaMA、GPT-4)的重要技术。 --- ### **传统注意力机制的问题** Transformer的核心是自注意力(Self-Attention),其计算复杂度为 \(O(N^2)\)(N是序列长度),具体问题包括: 1. **显存占用高**:需存储中间矩阵(QK^T、Softmax结果等),显存需求与\(N^2\)成正比。 2. **计算效率低**:频繁读写显存(如GPU的HBM)导致延迟,尤其是长序列场景。 --- ### **Flash Attention的核心思想** 通过**分块计算(Tiling)**和**重计算(Recomputation)**,减少显存访问次数,实现近似计算但结果严格一致。 #### 1. **分块计算(Tiling)** - 将大的Q、K、V矩阵分割成小块(Tile),在SRAM(高速缓存)中逐块计算局部注意力,避免一次性存储整个QK^T矩阵。 - 通过动态更新最终输出(O)和Softmax归一化的统计量(如最大值和求和项),避免中间结果存储。 #### 2. **重计算(Recomputation)** - 反向传播时,不存储前向的中间矩阵(如QK^T),而是根据保留的Q、K、V块重新计算梯度。 - 牺牲少量计算时间换取显存大幅降低。 --- ### **算法优势** 1. **显存优化**:显存占用从 \(O(N^2)\) 降为 \(O(N)\)。 2. **速度提升**:利用GPU显存层次结构(HBM → SRAM),减少I/O操作,计算速度提升2-4倍。 3. **支持更长序列**:可处理数万长度的序列(如DNA分析、长文本)。 --- ### **具体实现步骤(简化版)** 1. **前向传播**: - 将Q、K、V分块,每次加载一块到SRAM。 - 计算局部QK^T,应用Softmax并加权得到局部注意力输出。 - 根据局部Softmax统计量(最大值、求和值),迭代更新全局输出和归一化因子。 2. **反向传播**: - 利用前向保留的分块Q、K、V和输出梯度,重新计算局部QK^T和Softmax,避免存储中间矩阵。 --- ### **FlashAttention-2 的改进** 2023年提出的FlashAttention-2进一步优化: 1. **减少非矩阵运算**:优化Reshape、转置等操作。 2. **并行化**:在序列长度和批次维度上增加并行度。 3. **块大小调整**:根据GPU特性(如A100的SRAM大小)调整分块策略。 --- ### **应用场景** 1. **大模型训练**:如GPT、BERT、LLaMA等,显存减少可支持更大批次或更长序列。 2. **长序列任务**:文本摘要、高分辨率图像生成(如ViT)、基因序列分析。 3. **边缘设备**:降低显存需求后,可在资源有限的设备上部署模型。 --- ### **代码示例(伪代码)** ```python def flash_attention(Q, K, V, block_size=64): O = zeros_like(V) for i in range(0, N, block_size): Qi = Q[i:i+block_size] for j in range(0, N, block_size): Kj, Vj = K[j:j+block_size], V[j:j+block_size] Sij = Qi @ Kj.T # 分块计算QK^T Pij = softmax(Sij) # 局部Softmax Oi += Pij @ Vj # 累加到输出块 return O ``` --- ### **总结** Flash Attention通过**算法与硬件的协同设计**,在数学等价的前提下,将注意力计算复杂度从显存瓶颈的\(O(N^2)\)降至\(O(N)\),成为现代大模型的基石技术。其后续变种(如FlashDecoding)还优化了推理阶段的性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值