引言
当我决定深入研究Transformer架构时,我常常在阅读或观看网上教程时感到挫败,因为总觉得它们缺少一些关键内容:
- Tensorflow或Pytorch的官方教程使用了自己的API,保持在高层次,这迫使我不得不深入它们的代码库去了解底层实现。这非常耗时,而且阅读成千上万行代码也并不容易。
- 其他使用自定义代码的教程(文章末尾有链接)通常对用例过于简化,没有涉及诸如可变长度序列批处理的掩码处理等概念。
因此,我决定自己编写一个Transformer,以确保我理解这些概念,并能够将其应用于任何数据集。在本文中,我们将采用一种系统的方法,逐层、逐块地实现一个Transformer。
那么,为什么不使用TF/Pytorch的实现呢?本文的目的是教育性的,我并不打算超越Pytorch或Tensorflow的实现。我确实认为Transformer的理论和背后的代码并不容易理解,这也是我希望通过这个一步步的教程,让你更好地掌握这些概念,并在以后编写自己的代码时感到更自在的原因。从头开始构建自己的Transformer的另一个原因是,它将让你完全理解如何使用上述API。如果我们查看Pytorch中Transformer类的forward()
方法实现,你会看到很多晦涩的关键词,比如:
如果你已经熟悉这些关键词,那么你可以愉快地跳过这篇文章。否则,本文将带你逐一理解这些关键词及其背后的概念。
Transformer简短介绍
如果你听说过ChatGPT或Gemini,那么你已经遇到过Transformer了。事实上,ChatGPT中的“T”代表Transformer。该架构首次提出是在2017年,由谷歌研究人员在论文《Attention is All You Need》中提出。这是一种非常革命性的架构,因为以前的模型(用于序列到序列学习的机器翻译、语音转文本等)依赖于计算上昂贵的RNN,它们必须逐步处理序列。而Transformer只需要一次性地查看整个序列,将时间复杂度(Sequential Operations)从 O(n) 降低到 O(1)。
对于其他的复杂度指标,我们先在此做最简单的理解,相信你在看完本篇文章后,将会有更深刻的认识。Complexity per Layer代表每层(主要)的计算复杂度,对于自注意力(Self-Attention),假设输入序列的长度为n,特征维度为d,首先会生成3个(Q, K, V)长度为n,特征维度为d的嵌入矩阵,Q与K(的转置)做矩阵乘法得到n x n维度的注意力图,代表每个位置与其他位置的注意力权重,该步骤复杂度为,然后该注意力图与V做矩阵乘法得到更新后的
维度的输出序列特征表示,当然还有后面的线性变换层等操作,但主要的复杂度来自注意力图的计算。Sequential Operations代表每层需要依次进行的操作数,顺序操作数为 O(1),意味着这些操作可以并行化;顺序操作数为 O(n),代表需要依次进行n次操作。Maximum Path Length表示信息在层中传播所需经过的最大路径长度,Self-Attention只需操作一次就可以全局更新每个位置上的特征,而卷积需要
次操作(受限于卷积窗口的大小)才可以更新每个位置,尽管这些卷积操作可以并行计算。
通过这个表格可以看出,不同类型的层在计算复杂度、并行化能力和信息传播速度上的区别。自注意力层(Self-Attention)在并行化和信息传播速度上有明显优势,但计算复杂度较高。而循环层(Recurrent)虽然计算复杂度较低,但顺序操作数和最大路径长度较长,限制了并行化能力和信息传播速度。卷积层(Convolutional)在这三者之间,综合了计算复杂度和并行化能力。
Transformer的总体结构如下:
如果你每次看到这个图都不知所云,我可以先告诉你几点:
第一:在很多时候,包括你看到的很多文章里的模型,只用到了上图的编码器部分,用于提取输入序列的特征(如Vision Transformer, VIT)。而解码器主要用于构建完整的生成模型,用于像语言翻译、文字续写之类的生成任务。如果你的目的是分类这样的判别式任务,你大概率不会接触到Transformer解码器。
第二:假设你的任务是一个文字翻译任务,当你在训练网络的时候,你的input会是一个经编码后的序列,假设没编码前的序列是“我很帅”,那么,你的解码器的输入会是GT(“I am handsome”),然后一次性预测GT中每个位置的词,当然这里会用到Masked Attention来保证解码器在预测“I”时不会看到“ am handsome”。这在训练时是并行执行的,与推理时的自回归方式不同。
你需要明白,当我们说Transformer的时候,就是在说上面这个结构,他包含一个编码器和一个解码器,而不仅仅是编码器,更不是在说注意力机制。这点很重要,特别是当你来自计算机视觉领域,而不是自然语言处理。(手动狗头)
好,现在让我们进入正题。
多头自注意力
我们将实现的第一个块实际上是Transformer中最重要的部分,它被称为多头注意。让我们看看它在整个体系结构中的位置
注意是一种机制,实际上并不是Transformer所特有的,它已经在RNN sequence-to-sequence模型中使用。
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_dim=256, num_heads=4):
"""
input_dim: Dimensionality of the input.
num_heads: The number of attention heads to split the input into.
"""
super(MultiHeadAttention, self).__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Value part
self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Key part
self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Query part
self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False) # the output layer
def check_sdpa_inputs(self, x):
assert x.size(1) == self.num_heads, f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
assert x.size(3) == self.hidden_dim // self.num_heads
def scaled_dot_product_attention(self, query, key, value,
attention_mask=None, key_padding_mask=None):
"""
query : (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
key : (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
value : (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
attention_mask : (query_sequence_length, key_sequence_length)
key_padding_mask : (sequence_length, key_sequence_length)
"""
self.check_sdpa_inputs(query)
self.check_sdpa_inputs(key)
self.check_sdpa_inputs(value)
d_k = query.size(-1)
tgt_len, src_len = query.size(-2), key.size(-2)
# logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Attention mask here
if attention_mask is not None:
if attention_mask.dim() == 2:
assert attention_ma