文章目录
1 Self-Attention的概念
Self-Attention,也被称为scaled dot-product attention,是自然语言处理和深度学习领域的一个基本概念。它在机器翻译、文本摘要和情感分析等任务中发挥着关键作用。Self-Attention使模型能够在进行预测或捕获单词之间的依赖关系时权衡输入序列中不同部分的重要性。
在我们深入研究自我注意力之前,让我们先了解一下注意力这个更广泛的概念。想象一下阅读一份很长的文件;根据上下文,你的注意力自然地从一个词转移到另一个词。深度学习中的Attention mechanisms模仿了这种行为,允许模型选择性地专注于输入数据的特定元素,而忽略其他元素。
比如,在Encoder-Decoder框架中中,对于英-中机器翻译来说,Source是英文句子,Target是对应的翻译出的中文句子,每生成一个target单词都会考虑到Source和已生成的内容,接下来要生成的这个词就是通过计算待生成词与已生成词之间的Attention决定的。
Self-Attention中就是通过分析上下文,使模型能够辨别序列中单个词元的重要性,能够动态调整每个词元对下一个词元的影响程度。
注意
- Self-Attention存在于Transformer架构中的Encoder部分,也就是一串输入的词元中,每个词元的注意力分数都是跟上下文有关,会与前后词元进行计算。可以通过上图中看到。
- 而预测下一个词的时用到的注意力机制为:掩码注意力机制(Mask-Self-Attention),当前词元的注意力分数计算只与前文有关,只拿当前词元之前的词来计算。掩码注意力机制存在于Transformer中的Decoder中。
- 而位于Encoder-Decoder之间的,为交叉注意力机制,Cross-Self-Attention,键Key来自于Encoder,而Query来自于Decoder。
优缺点
-
优点:可以建立全局的依赖关系,扩大图像的感受野。相比于CNN,其感受野更大,可以获取更多上下文信息。
-
缺点:自注意力机制是通过筛选重要信息,过滤不重要信息实现的,这就导致其有效信息的抓取能力会比CNN小一些。这样是因为自注意力机制相比CNN,无法利用图像本身具有的尺度,平移不变性,以及图像的特征局部性(图片上相邻的区域有相似的特征,即同一物体的信息往往都集中在局部)这些先验知识,只能通过大量数据进行学习。这就导致自注意力机制只有在大数据的基础上才能有效地建立准确的全局关系,而在小数据的情况下,其效果不如CNN。
2 Self-Attention的原理
这里先给出Self-Attention的架构。
Q,K,V, and Self-Attention
- Query(Q): 可以将查询视为查找信息的元素。对于输入序列中的每个单词,计算一个查询向量。这些查询表示您希望在序列中关注的内容。
- Key(K):用于识别和定位序列中的重要元素。与查询一样,为每个单词计算键向量。
- Value(V): 对于每个单词,计算一个值向量。这些向量包含我们在确定序列中单词的重要性时要考虑的内容。
-Attention Scores: 准备好QKV后,计算序列中每对单词的注意分数。查询和关键字之间的关注分数量化了它们的兼容性或相关性。 - Weighted Aggregation:最后,将注意力得分作为权重,对值向量进行加权聚合。这种聚合产生自关注输出,表示输入序列的增强的和上下文知情的表示。
计算公式
Self-Attention的核心是Query(Q),Key(K),Value(V)。
通过对注意力机制的学习我们知道,对于注意力机制来说,键值对形式的Attention计算公式如下:
在Self-Attention中,公式中的K、Q、V表示如下图所示,可以看出QKV的来源都是输入矩阵X与矩阵的乘积,本质上都是X的线性变换(代码上就是分别经过了三个线性层),这也是为什叫做自注意力机制的原因。
从上式可以看出其计算过程为:首先,计算矩阵Q和K每一行向量的内积,为了防止内积过大,除以d_k的平方根;其次,**使用Softmax对上述内积的结果进行归一化;**最后得到Softmax矩阵之后和V相乘,得到最终的输出。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, dim_q, dim_k, dim_v):
super(SelfAttention, self).__init__()
self.dim_q = dim_q
self.dim_k = dim_k
11. self.dim_v = dim_v
# 定义线性变换函数
self.linear_q = nn.Linear(dim_q, dim_k, bias=False)
self.linear_k = nn.Linear(dim_q, dim_k, bias=False)
self.linear_v = nn.Linear(dim_q, dim_v, bias=False)
self._norm_fact = 1 / math.sqrt(dim_k)
def forward(self, x):
# x: batch, n, dim_q
# 根据文本获得相应的维度
batch, n, dim_q = x.shape
# 如果条件为 True,则程序继续执行;如果条件为 False,则程序抛出一个 AssertionError 异常,并停止执行。