2024年1月14日
自注意力是 LLM 的一大核心组件。对大模型及相关应用开发者来说,理解自注意力非常重要。近日,Ahead of AI 杂志运营者、机器学习和 AI 研究者 Sebastian Raschka 发布了一篇文章,介绍并用代码从头实现了 LLM 中的自注意力、多头注意力、交叉注意力和因果注意力。
这篇文章将介绍 Transformer 架构以及 GPT-4 和 Llama 等大型语言模型(LLM)中使用的自注意力机制。自注意力等相关机制是 LLM 的核心组件,因此如果想要理解 LLM,就需要理解它们。
不仅如此,这篇文章还会介绍如何使用 Python 和 PyTorch 从头开始编写它们的代码。在我看来,从头开始写算法、模型和技术的代码是一种非常棒的学习方式!
考虑到文章篇幅,我假设读者已经知道 LLM 并且已经对注意力机制有了基本了解。本文的目标和重点是通过 Python 和 PyTorch 编程过程来理解注意力机制的工作方式。
介绍自注意力
自注意力自在原始 Transformer 论文《Attention Is All You Need》中被提出以来,已经成为许多当前最佳的深度学习模型的一大基石,尤其是在自然语言处理(NLP)领域。由于自注意力已经无处不在,因此理解它是很重要的。

原始 Transformer 架构,来自论文 https://arxiv.org/abs/1706.03762
究其根源,深度学习中的「注意力(attention)」概念可以追溯到一种用于帮助循环神经网络(RNN)处理更长序列或句子的技术。举个例子,假如我们需要将一个句子从一种语言翻译到另一种语言。逐词翻译的操作方式通常不可行,因为这会忽略每种语言独有的复杂语法结构和习惯用语,从而导致出现不准确或无意义的翻译结果。

上图是不正确的逐词翻译,下图是正确的翻译结果
为了解决这个问题,研究者提出了注意力机制,让模型在每个时间步骤都能访问所有序列元素。其中的关键在于选择性,也就是确定在特定上下文中哪些词最重要。2017 年时,Transformer 架构引入了一种可以独立使用的自注意力机制,从而完全消除了对 RNN 的需求。

来自论文《Attention is All You Need》的插图,展示了 making 这个词对其它词的依赖或关注程度,其中的颜色代表注意力权重的差异。
对于自注意力机制,我们可以这么看:通过纳入与输入上下文有关的信息来增强输入嵌入的信息内容。换句话说,自注意力机制让模型能够权衡输入序列中不同元素的重要性,并动态调整它们对输出的影响。这对语言处理任务来说尤其重要,因为在语言处理任务中,词的含义可能会根据句子或文档中的上下文而改变。
请注意,自注意力有很多变体。人们研究的一个重点是如何提高自注意力的效率。然而,大多数论文依然是实现《Attention Is All You Need》论文中提出的原始的缩放点积注意力机制(scaled-dot product attention mechanism),因为对于大多数训练大规模 Transformer 的公司来说,自注意力很少成为计算瓶颈。
因此,本文重点关注的也是原始的缩放点积注意力机制(称为自注意力),毕竟这是实践中最流行和使用范围最广泛的注意力机制。但是,如果你对其它类型的注意力机制感兴趣,可以参阅其它论文:
-
Efficient Transformers: A Survey:https://arxiv.org/abs/2009.06732
-
A Survey on Efficient Training of Transformers:https://arxiv.org/abs/2302.01107
-
FlashAttention:https://arxiv.org/abs/2205.14135
-
FlashAttention-v2:https://arxiv.org/abs/2307.08691
对输入句子进行嵌入操作
开始之前,我们先考虑以下输入句子:「Life is short, eat dessert first」。我们希望通过自注意力机制来处理它。类似于其它类型的用于处理文本的建模方法(比如使用循环神经网络或卷积神经网络),我们首先需要创建一个句子嵌入(embedding)。
为了简单起见,这里我们的词典 dc 仅包含输入句子中出现的词。在真实世界应用中,我们会考虑训练数据集中的所有词(词典的典型大小在 30k 到 50k 条目之间)。
输入:
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s
in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
输出:
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
接下来,我们使用这个词典为每个词分配一个整数索引:
输入:
import torch
sentence_int = torch.tensor(
[dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)
输出:
tensor([0, 4, 5, 2, 1, 3])
现在,使用输入句子的整数向量表征,我们可以使用一个嵌入层来将输入编码成一个实数向量嵌入。这里,我们将使用一个微型的 3 维嵌入,这样一来每个输入词都可表示成一个 3 维向量。
请注意,嵌入的大小范围通常是从数百到数千维度。举个例子,Llama 2 的嵌入大小为 4096。这里之所以使用 3 维嵌入,是为了方便演示。这让我们可以方便地检视各个向量的细节。
由于这个句子包含 6 个词,因此最后会得到 6×3 维的嵌入:
输入:
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
输出:
tensor([[ 0.3374, -0.1778, -0.3035],
[ 0.1794, 1.8951, 0.4954],
[ 0.2692, -0.0770, -1.0205],
[-0.2196, -0.3792, 0.7671],
[-0.5880, 0.3486, 0.6603],
[-1.1925, 0.6984, -1.4097]])
torch.Size([6, 3])
定义权重矩阵
现在开始讨论广被使用的自注意力机制,也称为缩放点积注意,这是 Transformer 架构不可或缺的组成部分。
自注意力使用了三个权重矩阵,分别记为 W_q、W_k 和 W_v;它们作为模型参数,会在训练过程中不断调整。这些矩阵的作用是将输入分别投射成序列的查询、键和值分量。
相应的查询、键和值序列可通过权重矩阵 W 和嵌入的输入 x 之间的矩阵乘法来获得:
-
查询序列:对于属于序列 1……T 的 i,有 q⁽ⁱ⁾=x⁽ⁱ⁾W_q
-
键序列:对于属于序列 1……T 的 i,有 k⁽ⁱ⁾=x⁽ⁱ⁾W_k
-
值序列:对于属于序列 1……T 的 i,有 v⁽ⁱ⁾=x⁽ⁱ⁾W_v
-
索引 i 是指输入序列中的 token 索引位置,其长度为 T。

通过输入 x 和权重 W 计算查询、键和值向量
这里,q⁽ⁱ⁾ 和 k⁽ⁱ⁾ 都是维度为 d_k 的向量。投射矩阵 W_q 和 W_k 的形状为 d × d_k,而 W_v 的形状是 d × d_v。
(需要注意,d 表示每个词向量 x 的大小。)
由于我们要计算查询和键向量的点积,因此这两个向量的元素数量必须相同(d_q=d_k)。很多 LLM 也会使用同样大小的值向量,也即 d_q=d_k=d_v。但是,值向量 v⁽ⁱ⁾ 的元素数量可以是任意值,其决定了所得上下文向量的大小。
在接下来的代码中,我们将设定 d_q=d_k=2,而 d_v=4。投射矩阵的初始化如下:输入:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))
(类似于之前提到的词嵌入,实际应用中的维度 d_q、d_k、d_v 都大得多,这里使用小数值是为了方便演示。)
计算非归一化的注意力权重
现在假设我们想为第二个输入元素计算注意力向量 —— 也就是让第二个输入元素作为这里的查询:

对于接下来的章节,我们将重点关注第二个输入 x⁽²⁾。
写成代码就是这样:
输入:
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 =

本文详细介绍了自注意力、多头注意力、交叉注意力和因果注意力在大型语言模型(LLM)如Transformer、GPT-4和Llama中的应用。通过Python和PyTorch代码,阐述了自注意力的工作原理,包括输入嵌入、权重矩阵计算、注意力权重的计算与归一化,以及多头注意力和因果自注意力的实现。自注意力是LLM的核心组件,有助于模型理解和生成上下文相关的输出。
最低0.47元/天 解锁文章
2399

被折叠的 条评论
为什么被折叠?



