传统Transformer的自注意力机制中有个关键步骤:计算查询(Query)和键(Key)之间的相似度,然后通过softmax函数转换成注意力权重。这个过程会产生一个N×N的矩阵(N是序列长度),导致计算和内存消耗随着序列长度的平方增长,也就是O(N²)。当处理长文本时,比如几千个token,这会非常吃内存,训练和推理速度也会变慢。
那么,如何解决这个问题呢?最近的研究提出用核函数(kernel function)代替softmax,这样可以避免显式计算整个注意力矩阵,从而将复杂度从O(N²)降到O(N)。但具体是怎么做到的?核函数在这里起到了什么作用?
首先,回忆一下传统的注意力计算步骤。假设我们有查询Q、键K和值V,注意力输出是:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
/
s
q
r
t
(
d
)
)
V
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
Attention(Q,K,V)=softmax(QKT/sqrt(d))V
这里的QK^T 是计算每个查询和键的点积,形成一个N×N的矩阵,然后经过softmax归一化得到权重,再与V相乘。问题就在于这个QK^T矩阵,当N很大时,存储和计算它都非常昂贵。
核函数的思想是,找到一个函数φ,将Q和K映射到另一个空间,使得点积QK^T可以近似表示为φ(Q)和φ(K)的乘积。这样的话,注意力计算可以重写为:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
≈
φ
(
Q
)
(
φ
(
K
)
T
V
)
Attention(Q, K, V) ≈ φ(Q) (φ(K)^T V)
Attention(Q,K,V)≈φ(Q)(φ(K)TV)
这样,不需要显式计算N×N的矩阵,而是先计算φ(K)^T V(这是一个d×d的矩阵,d是特征维度),然后与φ(Q)相乘。这一步将计算复杂度从O(N²d)降到了O(Nd²),因为φ(K)^T V的计算是O(Nd²),而φ(Q)与之相乘是O(Nd²)。对于长序列来说,d通常远小于N,所以这样的方法显著节省了内存和计算时间。
但这里有几个关键点需要解释清楚:
-
核函数的选择:什么样的φ函数可以近似softmax的效果?通常,核函数需要满足某种分解性质,比如随机傅里叶特征(Random Fourier Features)或者多项式展开,这样点积的指数函数(在softmax中)可以被分解成两个向量的乘积。
-
近似带来的影响:使用核函数代替softmax是否会影响模型的表现?理论上,如果核函数能够很好地近似softmax的分布,那么模型效果应该接近。但实际中可能需要调整或设计更合适的核函数,或者在训练时进行微调。
-
具体实现方式:例如,Performer模型使用了基于正交随机特征的核(FAVOR+),将Q和K映射到高维空间,从而近似softmax的结果,而无需计算完整的注意力矩阵。
举个例子,假设我们有一个核函数φ,可以将每个查询和键映射到一个新的空间,使得 e x p ( q i ⋅ k j ) ≈ φ ( q i ) ⋅ φ ( k j ) exp(q_i·k_j) ≈ φ(q_i)·φ(k_j) exp(qi⋅kj)≈φ(qi)⋅φ(kj)。这样,原本的softmax注意力权重矩阵中的每个元素可以近似为两个映射后的向量的点积。然后,整个注意力计算可以通过先计算φ(K)^T V,再与φ(Q)相乘来得到,避免了显式的N×N矩阵。
这样做的好处是显而易见的:内存占用大幅减少,计算速度更快,尤其在处理长序列时效果显著。但可能的缺点是,核函数的近似可能会引入误差,导致模型性能略有下降,或者需要更复杂的训练策略来适应这种近似。
此外,不同的线性Transformer变体可能采用不同的核函数或近似方法。例如,Linformer通过低秩投影来近似注意力矩阵,而Performer使用随机特征映射。这些方法的核心思想都是找到一种更高效的方式来代替原始的高复杂度计算,同时尽量保留注意力机制的关键特性。
总结一下,用核函数代替softmax的本质是通过数学上的技巧,将原本需要显式计算的大矩阵分解或近似为更低维度的操作,从而在保持模型表达能力的同时,显著降低计算和内存开销。这对于扩展Transformer到更长序列的任务(如文档理解、基因组分析等)非常重要。
在注意力机制中使用核函数替代 softmax,核心目的是避免计算庞大的注意力矩阵,从而显著降低计算和内存开销。我们可以通过以下类比和步骤来理解这一技术:
通俗解释:从“全员投票”到“代表选举”
想象一个班级要评选优秀学生,传统方式是让每个学生(Query)给全班同学(Key)打分(计算所有相似度),最终汇总成庞大的评分表(N×N 的注意力矩阵),再根据评分分配奖励(Value)。这种方式效率低,尤其当班级人数(序列长度 N)极大时。
核函数的作用相当于引入“学生代表”:
- 选举代表:每个学生用一套规则(核函数 φ)选出一小群代表,代表们浓缩了全班的信息。
- 代表投票:代表之间快速汇总意见,再反馈给每个学生。
- 结果分配:学生根据代表反馈的结果领取奖励。
这种方法避免了全员投票的复杂性,只需处理代表们的精简信息,效率大幅提升。
技术细节拆解
1. 传统注意力机制的问题
- 步骤:
计算每个 Query 与所有 Key 的相似度 → Softmax 归一化 → 加权求和 Value。 - 复杂度:
生成 N×N 的相似度矩阵,消耗 O(N²) 内存和计算时间。
2. 核函数替代 Softmax 的改进
-
核心思想:
利用数学技巧(核函数 φ),将相似度计算分解为两步,避免显式生成 N×N 矩阵。 -
具体步骤:
- Step 1:映射到核空间
用核函数 φ 将每个 Query 和 Key 映射到新的空间,使得相似度计算可表示为:
相似度 ( Q i , K j ) ≈ ϕ ( Q i ) ⋅ ϕ ( K j ) \text{相似度}(Q_i, K_j) \approx \phi(Q_i) \cdot \phi(K_j) 相似度(Qi,Kj)≈ϕ(Qi)⋅ϕ(Kj) - Step 2:分解计算
将注意力输出改写为:
注意力输出 = ϕ ( Q ) ⋅ ( ϕ ( K ) T ⋅ V ) \text{注意力输出} = \phi(Q) \cdot \left( \phi(K)^T \cdot V \right) 注意力输出=ϕ(Q)⋅(ϕ(K)T⋅V)
这里先计算 ϕ ( K ) T ⋅ V \phi(K)^T \cdot V ϕ(K)T⋅V(复杂度 O(Nd²)),再与 ϕ ( Q ) \phi(Q) ϕ(Q)相乘(复杂度 O(Nd²)),总复杂度从 O(N²d) 降为 O(Nd²)。
- Step 1:映射到核空间
3. 核函数的选择
- 理想要求:
核函数需保证 ϕ ( Q ) ϕ ( K ) T \phi(Q)\phi(K)^T ϕ(Q)ϕ(K)T 近似于原始注意力矩阵(含 Softmax)。 - 常见方法:
- 随机傅里叶特征(RFF):模拟高斯核,近似指数函数(Softmax 的核心)。
- 正交随机投影:提升近似稳定性(如 Performer 模型的 FAVOR+ 算法)。
- 多项式核:展开为低秩矩阵乘积。
实例说明(Performer 模型)
-
映射函数 φ:
使用随机矩阵将 Q/K 投影到高维空间,近似 Softmax 中的指数计算。- 例如: ϕ ( q ) = 1 m [ cos ( q W 1 ) , sin ( q W 1 ) , . . . , cos ( q W m ) , sin ( q W m ) ] \phi(q) = \frac{1}{\sqrt{m}}[\cos(qW_1), \sin(qW_1), ..., \cos(qW_m), \sin(qW_m)] ϕ(q)=m1[cos(qW1),sin(qW1),...,cos(qWm),sin(qWm)],其中 (W) 是随机矩阵。
-
计算流程:
- 先计算 ϕ ( K ) T V \phi(K)^T V ϕ(K)TV(聚合所有 Key-Value 对)。
- 再用 ϕ ( Q ) \phi(Q) ϕ(Q) 与之相乘,得到每个 Query 的注意力结果。
- 效果:处理 1 万 token 的序列时,内存消耗从 100GB 降至 1GB 以下。
优势与权衡
方面 | 传统注意力(Softmax) | 核函数注意力 |
---|---|---|
计算复杂度 | O(N²d)(如 2048²×64) | O(Nd²)(如 2048×64²) |
内存占用 | 存储 N×N 矩阵(如 4M 条目) | 仅存储映射后的低维矩阵 |
近似误差 | 精确计算 | 需设计高质量核函数减少误差 |
适用场景 | 短序列任务(如翻译) | 长序列任务(如基因组分析) |
总结
- 核函数的作用:通过数学变换,将昂贵的全矩阵计算分解为低维操作,实现线性复杂度。
- 实际意义:使 Transformer 能够处理超长文本(如整本书、DNA 序列),同时保持训练和推理效率。
- 类比记忆:
- 传统 Softmax → 全员投票,效率低但精确。
- 核函数 → 代表选举,效率高且近似全局结果。