Transformer Quality in Linear Time

TL;DR

2022 年谷歌发表的 transformer 结构优化论文,本文提出了两种创新方法——门控注意单元(GAU)和混合块注意力(Mixed Chunk Attention),共同构成FLASH模型。FLASH 在短序列和长序列任务中均能匹配 Transformer 的质量,同时显著提升训练速度,为高效长序列建模提供了新思路。

Paper name

Transformer Quality in Linear Time

Paper Reading Note

Paper URL:

  • https://arxiv.org/abs/2202.10447

Introduction

背景

  • Transformer 模型随输入长度呈二次复杂度增长。这一限制阻碍了 Transformer 处理长期信息,而长期信息对于许多应用来说至关重要。
  • 为了加速 Transformer 在长上下文中的计算效率,研究人员提出了多种更高效的注意力机制。现有高效注意力方法通常存在以下缺陷
    • 质量下降:相比于增强后的 Transformer,现有的高效注意力方法往往会导致显著的质量下降,而这种质量损失超过了它们在计算效率上的优势
    • 实际开销较大:许多高效注意力方法在 Transformer 层中引入了额外的复杂性,并需要大量的内存重排操作,导致它们的 理论复杂度 与 实际加速性能(在 GPU 或 TPU 上的表现) 之间存在较大差距。
    • 自回归训练效率低:大多数注意力线性化技术在推理过程中可以实现快速解码,但在 自回归任务(如语言建模) 训练时却极为缓慢。这主要是因为它们采用 类似 RNN 的逐步状态更新方式,需要在大量训练步骤中不断维护和更新状态,无法充分利用现代加速器的并行计算能力。

本文方案

  • 提出了以下两种模型单元和架构
    • 提出了一种名为门控注意单元(Gated Attention Unit, GAU)的简单层,它允许使用较弱的单头注意力机制,同时仅带来最小的质量损失
    • 提出了一种与该新层互补的线性近似方法,最终得到的模型被命名为 FLASH,在短(512)和长(8K)上下文长度上均能匹配改进版 Transformer 的困惑度(perplexity)
  • 在 Wiki-40B 数据集上实现了最高 4.9 倍的训练加速,在 PG-19 数据集上达到了 12.1 倍加速;在 C4 数据集上的掩码语言建模任务中,训练速度提升了 4.8 倍
    在这里插入图片描述

Methods

门控注意单元(Gated Attention Unit, GAU)

  • 门控注意单元(GAU) 是比 Transformer 更简单但更高效的计算层。虽然 GAU 仍然具有 随上下文长度呈二次复杂度的计算成本,但它更适合作为后续近似方法的基础。首先介绍相关层的背景:
标准 MLP

X ∈ R T × d X \in \mathbb{R}^{T \times d} XRT×d T T T 个 token 的表示,Transformer 中的 MLP 计算如下:

O = ϕ ( X W u ) W o O = \phi(X W_u) W_o O=ϕ(XWu)Wo

其中, W u ∈ R d × e W_u \in \mathbb{R}^{d \times e} WuRd×e W o ∈ R e × d W_o \in \mathbb{R}^{e \times d} WoRe×d d d d 为模型维度, e e e 为扩展的中间层维度, ϕ \phi ϕ 为逐元素的激活函数。

门控线性单元(Gated Linear Unit, GLU)

GLU 是在 MLP 的基础上加入门控机制的改进(Dauphin 等,2017)。GLU 在多个任务中被证明有效(Shazeer,2020;Narang 等,2021),并已应用于最先进的 Transformer 语言模型(Du 等,2021;Thoppilan 等,2022)。其计算如下:

U = ϕ u ( X W u ) , V = ϕ v ( X W v ) ∈ R T × e U = \phi_u(X W_u), \quad V = \phi_v(X W_v) \in \mathbb{R}^{T \times e} U=ϕu(XWu),V=ϕv(XWv)RT×e

O = ( U ⊙ V ) W o ∈ R T × d O = (U \odot V) W_o \in \mathbb{R}^{T \times d} O=(UV)WoRT×d

其中 ⊙ \odot 表示逐元素乘法。在 GLU 中,每个表示 u i u_i ui 由另一个与相同 token 关联的表示 v i v_i vi 进行门控。

门控注意单元(GAU)

在这里插入图片描述
GAU 的核心思想是 将注意力和 GLU 统一成一个计算层,并尽可能共享它们的计算(如图 2 所示)。这样不仅提高了 参数/计算效率,还能自然引入 强大的注意力门控机制。具体而言,GAU 推广 了 GLU 中的公式:

O = ( U ⊙ V ^ ) W o , V ^ = A V O = (U \odot \hat{V}) W_o, \quad \hat{V} = A V O=(UV^)Wo,V^=AV

其中, A ∈ R T × T A \in \mathbb{R}^{T \times T} ART×Ttoken 之间的注意力权重。与 GLU 始终使用 v i v_i vi 来门控 u i u_i ui(两者都与相同 token 相关)不同,GAU 用注意力计算出的 v ^ i = ∑ j a i j v j \hat{v}_i = \sum_j a_{ij} v_j v^i=jaijvj 代替 v i v_i vi,从所有可用的 token 中“检索”更相关的信息。当 A A A 为单位矩阵时,该公式将退化为 GLU。

与 Liu 等(2021)的研究结果一致,GAU 的门控机制允许使用比 MHSA 更简单、更弱 的注意力机制,而不会影响质量:

Z = ϕ z ( X W z ) ∈ R T × s Z = \phi_z (X W_z) \in \mathbb{R}^{T \times s} Z=ϕz(XWz)RT×s

A = relu 2 ( Q ( Z ) K ( Z ) ⊤ + b ) ∈ R T × T A = \text{relu}^2 (Q(Z) K(Z)^\top + b) \in \mathbb{R}^{T \times T} A=relu2(Q(Z)K(Z)+b)RT×T

其中, Z Z Z 是共享表示( s ≪ d s \ll d sd), Q Q Q K K K 是两个轻量级变换,对 Z Z Z 进行逐维缩放和平移(类似于 LayerNorm 的可学习参数), b b b 为相对位置偏置。此外,我们发现 在 GAU 中 softmax 可以被替换为更简单的激活函数

表 1 展示了不同修改对 GAU 质量的影响:

修改PPLX (LM/MLM)参数量 (M)
原始 GAU16.78 / 4.23105
relu² → softmax17.04 / 4.31105
单头 → 多头17.76 / 4.48105
无门控17.45 / 4.58131

此外,GAU 仅在 GLU 之上引入了一个小型的 稠密矩阵 W z W_z Wz(参数量 d s d s ds,相比于 Transformer 中 MHSA 的 4 d 2 4d^2 4d2 参数,其计算复杂度大幅降低。通过设定 e = 2 d e = 2d e=2d,可以 用两个 GAU 层替换每个 Transformer 块(MLP/GLU + MHSA),同时保持相似的模型大小和训练速度

基于 GAU 的快速线性注意力(FLASH)

前面的研究给出了两个启发,使我们能够扩展 GAU 以处理长序列:

  • GAU 的门控机制允许使用更弱(单头、无 softmax)的注意力,而不会损失质量。如果将这一思想扩展到长序列建模,GAU 也能增强近似(弱)注意力机制(如 局部注意力、稀疏注意力、线性注意力)的效果。在接下来的部分,我们将介绍如何利用这一特点,开发 真正高效的线性注意力模型 FLASH
  • 额外的注意力模块数量加倍的影响:此外,由于 MLP+MHSA≈2×GAU 在计算成本上是等价的,因此 GAU 天然具有两倍的注意力模块数量。由于近似注意力通常需要更多层才能捕获完整的依赖关系(Dai 等,2019;Child 等,2019),这一特性使 GAU 在处理长序列时更具吸引力。基于这一直觉,我们首先回顾了一些现有的 长序列建模方法,然后介绍 如何让 GAU 在长序列上以线性时间实现 Transformer 级别的质量

现有的线性复杂度变体

部分注意力(Partial Attention)

一种常见的方法是使用 局部/稀疏注意力 近似完整的注意力矩阵,这些方法包括:

  • 局部窗口注意力(Dai 等,2019;Rae 等,2019)
  • 局部 + 稀疏注意力(Child 等,2019;Li 等,2019;Beltagy 等,2020;Zaheer 等,2020)
  • 轴向注意力(Ho 等,2019;Huang 等,2019)
  • 基于哈希(Kitaev 等,2020)或聚类(Roy 等,2021)学习的模式

虽然这些方法不能完全匹配完整注意力的效果,但它们通常能够 通过扩展至更长的序列来获得质量提升。然而,这类方法的主要问题在于,它们 涉及大量的不规则或规则的内存重排操作(如 gatherscattersliceconcatenation),这些操作 不适用于现代加速器的高并行性计算,特别是 TPU 这类专用 ASIC 硬件。因此,它们的 理论效率和实际加速效果 之间通常存在较大差距。因此,在本研究中,我们 刻意减少了模型中的内存重排操作

线性注意力(Linear Attention)

另一种流行的方法是 通过分解注意力矩阵并重新排列矩阵乘法的顺序 来线性化计算(Choromanski 等,2020;Wang 等,2020;Katharopoulos 等,2020;Peng 等,2021)。其计算公式如下:

V ^ lin = Q ( K ⊤ V ) ≈ V ^ quad = Softmax ( Q K ⊤ ) V \hat{V}_{\text{lin}} = Q (K^\top V) \approx \hat{V}_{\text{quad}} = \text{Softmax} (Q K^\top) V V^lin=Q(KV)V^quad=Softmax(QK)V

其中, Q , K , V ∈ R T × d Q, K, V \in \mathbb{R}^{T \times d} Q,K,VRT×d 分别为 查询(query)、键(key)和值(value)。这种 重排计算顺序 使得复杂度从 O ( T 2 d ) O(T^2 d) O(T2d) 降至 O ( T d ) O(T d) O(Td)

线性注意力的 另一个优点 是,在推理过程中,每个 自回归解码步骤 的计算和内存需求是 常数

M t = M t − 1 + K t V t ⊤ M_t = M_{t-1} + K_t V_t^\top Mt=Mt1+KtVt

这意味着:

  • 只需维护一个 O ( d 2 ) O(d^2) O(d2) 大小的缓存
  • 每个新输入 t t t 仅需 O ( d 2 ) O(d^2) O(d2) 计算量来更新 M t M_t Mt

相比之下,完整的二次注意力每一步解码时都需要 O ( T d ) O(T d) O(Td) 计算和内存,因为每个新输入必须 与所有先前的输入进行注意力计算

然而,线性注意力在 自回归训练阶段存在严重的低效问题

  • 由于 自回归因果约束,每个时间步的查询 Q t Q_t Qt 需要不同的缓存值 M t = K : t ⊤ V : t M_t = K_{:t}^\top V_{:t} Mt=K:tV:t
  • 这意味着模型 需要存储 T T T 个不同的 M t M_t Mt,而非自回归情况下的 单个 K ⊤ V K^\top V KV
  • 在理论上,可以通过 累积求和(cumsum) 计算:

M t = ∑ i = 1 t K i V i ⊤ M_t = \sum_{i=1}^{t} K_i V_i^\top Mt=i=1tKiVi

但在实践中,这种 累积求和操作 在现代加速器上 引入了类似 RNN 的顺序依赖,导致:

  • 并行计算受限
  • 每个步骤都需要 T T T 次内存访问
  • 导致实际计算速度大幅下降

在 TPU 和 GPU 上,直接计算完整二次注意力甚至比线性注意力更快

混合块注意力(Mixed Chunk Attention)

基于现有 线性复杂度注意力方法 的优缺点,我们提出了一种新的 混合块注意力(Mixed Chunk Attention),结合了 部分注意力线性注意力 的优势,如图 4 所示。以下是具体方法。
在这里插入图片描述

预处理

首先,我们将输入序列 划分为 G G G 个不重叠的块(chunk),每个块大小为 C C C,即:

[ T ] → [ T / C × C ] [T] \rightarrow [T / C \times C] [T][T/C×C]

然后,对每个块 g g g,根据 GAU 公式(见式 (1) 和 (4)),计算:

U g ∈ R C × e , V g ∈ R C × e , Z g ∈ R C × s U_g \in \mathbb{R}^{C \times e}, \quad V_g \in \mathbb{R}^{C \times e}, \quad Z_g \in \mathbb{R}^{C \times s} UgRC×e,VgRC×e,ZgRC×s

接着,我们从 Z g Z_g Zg 生成四种注意力头:

Q g quad , K g quad , Q g lin , K g lin Q_g^{\text{quad}}, K_g^{\text{quad}}, Q_g^{\text{lin}}, K_g^{\text{lin}} Qgquad,Kgquad,Qglin,Kglin

这些变换都是 逐维缩放和平移(计算开销极低)。接下来,我们介绍 如何高效地近似 GAU 的注意力

局部块内注意力

首先,在 每个块内 独立应用 二次注意力

V ^ g quad = ReLU 2 ( Q g quad K g quad ⊤ + b ) V g \hat{V}_g^{\text{quad}} = \text{ReLU}^2 (Q_g^{\text{quad}} K_g^{\text{quad} \top} + b) V_g V^gquad=ReLU2(QgquadKgquad+b)Vg

其计算复杂度为:

O ( G × C 2 × d ) = O ( T C d ) O(G \times C^2 \times d) = O(T C d) O(G×C2×d)=O(TCd)

如果 C C C 为常数,则该部分计算复杂度对 T T T 是线性的

全局块间注意力

此外,我们使用 全局线性注意力 处理 块间的长程交互

  • 非因果(Non-Causal):

    V ^ lin = Q lin ( ∑ h = 1 G K h lin ⊤ V h ) \hat{V}^{\text{lin}} = Q^{\text{lin}} \left( \sum_{h=1}^{G} K_h^{\text{lin} \top} V_h \right) V^lin=Qlin(h=1GKhlinVh)

  • 因果(Causal):

    V ^ lin = Q lin ( ∑ h = 1 g − 1 K h lin ⊤ V h ) \hat{V}^{\text{lin}} = Q^{\text{lin}} \left( \sum_{h=1}^{g-1} K_h^{\text{lin} \top} V_h \right) V^lin=Qlin(h=1g1KhlinVh)

注意:上述求和运算是在 块级别 进行的,因此相比于 token 级别的线性注意力,累积求和的步数减少了 C C C 倍(通常设 C = 256 C=256 C=256,大幅加快训练速度。

最终, V ^ g quad \hat{V}_g^{\text{quad}} V^gquad V ^ lin \hat{V}^{\text{lin}} V^lin 相加,并经过门控和后处理:

O g = [ U g ⊙ ( V ^ g quad + V ^ lin ) ] W o O_g = \left[ U_g \odot (\hat{V}_g^{\text{quad}} + \hat{V}^{\text{lin}}) \right] W_o Og=[Ug(V^gquad+V^lin)]Wo


讨论:为什么混合块注意力更快?

  • 自回归训练加速:由于 块化操作,自回归训练中的顺序依赖 T T T 下降到 T / C T/C T/C,大幅提高了训练速度。
  • 非重叠块注意力:相比 重叠局部注意力(如 Longformer、BigBird),我们的方法减少了 内存重排操作,在 TPU 上表现更优。

Experiments

  • FLASH 在长序列上达到了 Transformer 级别的质量,同时实现了线性复杂度加速
    在这里插入图片描述

Conclusion

  • GAU 以简洁的结构融合门控与注意力,显著降低了参数量和计算成本,同时保持模型表现力。
  • 论文的研究思路很值得参考,从基于 GLU 的思考,想办法降低 attention 在模型中的重要性占比,然后用线性 attention 近似就不会造成过大的性能损失,整个逻辑很顺畅。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值