从头实现一个完整的Transformer模型

引言

        当我决定深入研究Transformer架构时,我常常在阅读或观看网上教程时感到挫败,因为总觉得它们缺少一些关键内容:

  • Tensorflow或Pytorch的官方教程使用了自己的API,保持在高层次,这迫使我不得不深入它们的代码库去了解底层实现。这非常耗时,而且阅读成千上万行代码也并不容易。
  • 其他使用自定义代码的教程(文章末尾有链接)通常对用例过于简化,没有涉及诸如可变长度序列批处理的掩码处理等概念。

因此,我决定自己编写一个Transformer,以确保我理解这些概念,并能够将其应用于任何数据集。在本文中,我们将采用一种系统的方法,逐层、逐块地实现一个Transformer。

        那么,为什么不使用TF/Pytorch的实现呢?本文的目的是教育性的,我并不打算超越Pytorch或Tensorflow的实现。我确实认为Transformer的理论和背后的代码并不容易理解,这也是我希望通过这个一步步的教程,让你更好地掌握这些概念,并在以后编写自己的代码时感到更自在的原因。从头开始构建自己的Transformer的另一个原因是,它将让你完全理解如何使用上述API。如果我们查看Pytorch中Transformer类的forward()方法实现,你会看到很多晦涩的关键词,比如:

        如果你已经熟悉这些关键词,那么你可以愉快地跳过这篇文章。否则,本文将带你逐一理解这些关键词及其背后的概念。

Transformer简短介绍

        如果你听说过ChatGPT或Gemini,那么你已经遇到过Transformer了。事实上,ChatGPT中的“T”代表Transformer。该架构首次提出是在2017年,由谷歌研究人员在论文《Attention is All You Need》中提出。这是一种非常革命性的架构,因为以前的模型(用于序列到序列学习的机器翻译、语音转文本等)依赖于计算上昂贵的RNN,它们必须逐步处理序列。而Transformer只需要一次性地查看整个序列,将时间复杂度(Sequential Operations)从 O(n) 降低到 O(1)。

        对于其他的复杂度指标,我们先在此做最简单的理解,相信你在看完本篇文章后,将会有更深刻的认识。Complexity per Layer代表每层(主要)的计算复杂度,对于自注意力(Self-Attention),假设输入序列的长度为n,特征维度为d,首先会生成3个(Q, K, V)长度为n,特征维度为d的嵌入矩阵,Q与K(的转置)做矩阵乘法得到n x n维度的注意力图,代表每个位置与其他位置的注意力权重,该步骤复杂度为O(n^2 d),然后该注意力图与V做矩阵乘法得到更新后的n*d维度的输出序列特征表示,当然还有后面的线性变换层等操作,但主要的复杂度来自注意力图的计算。Sequential Operations代表每层需要依次进行的操作数,顺序操作数为 O(1),意味着这些操作可以并行化;顺序操作数为 O(n),代表需要依次进行n次操作。Maximum Path Length表示信息在层中传播所需经过的最大路径长度,Self-Attention只需操作一次就可以全局更新每个位置上的特征,而卷积需要log_k (n)次操作(受限于卷积窗口的大小)才可以更新每个位置,尽管这些卷积操作可以并行计算。

        通过这个表格可以看出,不同类型的层在计算复杂度、并行化能力和信息传播速度上的区别。自注意力层(Self-Attention)在并行化和信息传播速度上有明显优势,但计算复杂度较高。而循环层(Recurrent)虽然计算复杂度较低,但顺序操作数和最大路径长度较长,限制了并行化能力和信息传播速度。卷积层(Convolutional)在这三者之间,综合了计算复杂度和并行化能力。

Transformer的总体结构如下:

        如果你每次看到这个图都不知所云,我可以先告诉你几点:

        第一:在很多时候,包括你看到的很多文章里的模型,只用到了上图的编码器部分,用于提取输入序列的特征(如Vision Transformer, VIT)。而解码器主要用于构建完整的生成模型,用于像语言翻译、文字续写之类的生成任务。如果你的目的是分类这样的判别式任务,你大概率不会接触到Transformer解码器。

        第二:假设你的任务是一个文字翻译任务,当你在训练网络的时候,你的input会是一个经编码后的序列,假设没编码前的序列是“我很帅”,那么,你的解码器的输入会是GT(“I am handsome”),然后一次性预测GT中每个位置的词,当然这里会用到Masked Attention来保证解码器在预测“I”时不会看到“ am handsome”。这在训练时是并行执行的,与推理时的自回归方式不同。

        你需要明白,当我们说Transformer的时候,就是在说上面这个结构,他包含一个编码器和一个解码器,而不仅仅是编码器,更不是在说注意力机制。这点很重要,特别是当你来自计算机视觉领域,而不是自然语言处理。(手动狗头)

        好,现在让我们进入正题。

多头自注意力

        我们将实现的第一个块实际上是Transformer中最重要的部分,它被称为多头注意。让我们看看它在整个体系结构中的位置

        注意是一种机制,实际上并不是Transformer所特有的,它已经在RNN sequence-to-sequence模型中使用。

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim=256, num_heads=4):
        """
        input_dim: Dimensionality of the input.
        num_heads: The number of attention heads to split the input into.
        """
        super(MultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Value part
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Key part
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Query part
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False) # the output layer
        
        
    def check_sdpa_inputs(self, x):
        assert x.size(1) == self.num_heads, f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
        assert x.size(3) == self.hidden_dim // self.num_heads
        
        
    def scaled_dot_product_attention(self, query, key, value, 
            attention_mask=None, key_padding_mask=None):
        """
        query : (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
        key : (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        value :  (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        attention_mask : (query_sequence_length, key_sequence_length)
        key_padding_mask : (sequence_length, key_sequence_length)
    
        """
        self.check_sdpa_inputs(query)
        self.check_sdpa_inputs(key)
        self.check_sdpa_inputs(value)
        
        
        d_k = query.size(-1)
        tgt_len, src_len = query.size(-2), key.size(-2)

        
        # logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 
        
        # Attention mask here
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                assert attention_ma
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值