该问题归类到Transformer架构问题集——注意力机制——跨模态与多模态。请参考LLM数学推导——Transformer架构问题集。
1. 问题背景:长序列建模的复杂度困境
传统 Transformer 的自注意力机制时间复杂度为 (n 为序列长度),当处理文档级长文本(如
甚至更长)时,计算量会爆炸式增长(如
百万次操作)。Longformer 提出的 ** 滑动窗口注意力(Sliding Window Attention)** 通过限制每个位置的注意力范围,将复杂度降至
(k 为窗口大小),使其能高效处理长序列。
2. 滑动窗口注意力的核心机制
核心思想:每个位置仅关注其左右各 k 个邻居,形成大小为 的局部窗口(边界位置窗口大小递减)。
- 非重叠窗口:早期实现中窗口不重叠,但会导致上下文断裂;
- 滑动窗口(重叠窗口):窗口每次滑动 s 步(通常
或 k),确保上下文连续性。 图示:对于序列
,若
,窗口依次为
、
、
、
,每个位置被
个窗口覆盖。
3. 时间复杂度推导
3.1 单头滑动窗口注意力
设序列长度为 n,窗口大小为 (双边窗口,左右各 k 个位置)。
- 查询 - 键 - 值计算: 每个位置的查询
与窗口内 w 个键
、值
计算注意力:
单次注意力计算复杂度为
(
为键向量维度)。
- 总计算量: 每个位置参与 1 次查询,共 n 个位置,总复杂度为:
当
时,复杂度从
降至线性级
。
3.2 多头滑动窗口注意力(Multi-Head Attention, MHA)
设头数为 h,每头维度为 ,则总复杂度为:
与传统 MHA 的
相比,长序列下优势显著(如
时,复杂度从
百万降至
百万)。
4. 空间复杂度推导
- 键值存储:每个窗口需存储 w 个键 K 和值 V,共 n 个位置,总存储空间为:
(
为值向量维度,通常
)。
- 注意力权重:每个窗口存储
的注意力矩阵,但实际中可通过稀疏存储优化,空间复杂度仍为
量级。
5. 边界处理与实际复杂度修正
- 边界窗口缩小:序列前 k 个位置和后 k 个位置的窗口大小不足 w,但当
时,边界效应可忽略,总复杂度仍近似为
。
- 滑动步长 s:若窗口每次滑动 s 步(如 s=k),则总窗口数为
,每个位置被
个窗口覆盖。但 Longformer 默认使用 s=1(完全重叠窗口),确保上下文连续性。
6. 与全局注意力的结合:Longformer 的混合复杂度
Longformer 为兼顾长距离依赖,引入全局注意力(少数关键位置参与全序列注意力),其复杂度为 (m 为全局位置数)。整体复杂度为:
当
时,仍以滑动窗口的
为主导。
在 LLM 中的应用:长文本场景的刚需
1. 文档级任务
- 机器翻译:处理长段落时,滑动窗口允许模型聚焦局部上下文(如句子内短语),避免全局计算的爆炸式开销。
- 问答系统:在数百页的文档中检索答案时,滑动窗口可逐段提取关键信息,结合全局注意力定位跨段落关联。
2. 代码预训练
代码通常具有长序列特性(如函数定义跨多行),滑动窗口注意力能高效捕捉代码结构的局部依赖(如变量声明与引用),同时通过全局注意力处理跨函数的调用关系。
3. 对话历史建模
在多轮对话中(如对话历史长度 n=1024),滑动窗口可聚焦最近的 2k+1 轮对话,避免传统 Transformer 对整个历史的 计算,提升实时交互效率。
代码示例:滑动窗口注意力的简化实现
import torch
from torch.nn import Module, Linear
import torch.nn.functional as F
class SlidingWindowAttention(Module):
def __init__(self, d_model, n_heads, window_size=128):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.window_size = window_size # 单边窗口大小,总窗口=2*window_size+1
self.qkv_proj = Linear(d_model, 3 * d_model) # 投影层
self.out_proj = Linear(d_model, d_model)
def forward(self, x):
B, N, D = x.shape # B=批次, N=序列长度, D=d_model
qkv = self.qkv_proj(x).chunk(3, dim=-1) # 拆分为Q, K, V
q, k, v = [t.reshape(B, N, self.n_heads, D//self.n_heads).transpose(1, 2)
for t in qkv] # 形状变为 (B, h, N, d_head)
# 滑动窗口处理:为每个位置提取左右window_size个邻居
# 简化实现,实际需处理边界(此处假设N > 2*self.window_size)
padded = F.pad(x, (0, 0, self.window_size, self.window_size)) # 左右各填充window_size
indices = torch.arange(-self.window_size, self.window_size+1, device=x.device)
k_window = padded[:, N+indices, :] # 提取每个位置的窗口键值(简化索引)
v_window = padded[:, N+indices, :]
# 计算注意力分数:(B, h, N, window_size)
attn_scores = (q @ k_window.transpose(-2, -1)) / (D//self.n_heads)**0.5
attn_weights = F.softmax(attn_scores, dim=-1)
out = attn_weights @ v_window # (B, h, N, d_head)
out = out.transpose(1, 2).reshape(B, N, D)
return self.out_proj(out)
代码解析:
- 窗口填充:通过
F.pad
在序列左右添加window_size
个 padding,便于统一索引; - 索引提取:用
indices
模拟滑动窗口的相对位置,提取每个查询对应的键值对; - 复杂度控制:每个查询仅与 2k+1个键值对计算,避免全序列遍历。 注意:实际 Longformer 实现会更复杂,需处理边界索引、缓存历史窗口等优化。
总结:滑动窗口的 “局部 - 全局” 智慧
滑动窗口注意力通过将长序列拆解为局部窗口,以 O(nk) 复杂度突破传统 Transformer 的 瓶颈,成为 LLM 处理文档级长文本的核心技术。其核心思想在于:用局部上下文的高效计算,结合全局注意力的长距离建模,在复杂度与表达能力间找到平衡。这一设计不仅适用于文本领域,也为图像分块、视频时序建模等长序列问题提供了通用思路。