def precompute_freqs_cis(
hidden_size: int,
max_seq_len: int,
base: int = 10000,
num_attention_heads: int = 8
) -> torch.Tensor:
"""
预计算复数形式的旋转位置编码 (Rotary Position Embedding, RoPE)
🔧 核心思想:
- 每个 token 的表示在每个维度上被看作一个二维向量 (x, y)
- 位置 i 对应的向量会被旋转 i×θ_d 弧度
- 使用复数乘法来高效实现这种旋转:e^(iθ) * (x + iy)
📚 所属知识领域:
- 位置编码设计(Positional Encoding)
- 复数数学在深度学习中的应用
- 因果语言建模的位置感知机制
Args:
hidden_size: 模型隐藏层维度(如 512)
max_seq_len: 最大支持序列长度(如 8192)
base: 控制频率衰减速度的基数,默认为 10000
num_attention_heads: 注意力头数(用于确定 head_dim)
Returns:
freqs_cis: complex tensor of shape (max_seq_len, head_dim)
表示每个位置对应的复数旋转因子 e^(iθ)
"""
head_dim = hidden_size // num_attention_heads # 每个注意力头的维度
# 计算逆频率:θ_d = 1 / (base^(2d/head_dim)) → 控制不同维度有不同的周期性
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len) # 位置索引 [0, 1, ..., S-1]
# 构造每个位置和每个频率的乘积:freqs[s,d] = t[s] * inv_freq[d]
freqs = torch.einsum("s,d->sd", t, inv_freq) # (S, D//2)
# 将频率复制两次,扩展到完整 head_dim 维度
# 原因:我们对相邻两维做联合旋转(如 x₀,x₁ → 旋转成新坐标),所以需要两个相同频率
freqs = torch.cat([freqs, freqs], dim=-1) # now (S, D)
# 转换为复数形式:cos(θ) + i*sin(θ),即单位圆上的点
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # magnitude=1, angle=freqs
return freqs_cis # shape: (max_seq_len, head_dim)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
在最后一个维度上,将后半部分移到前面并取负。
示例:
x = [x0, x1, x2, x3] → [-x2, -x3, x0, x1]
这对应于复数中乘以 i 的操作(相当于旋转90度):
i*(x0 + ix1) = -x1 + ix0 → 交换并加负号
📌 注意:这是 apply_rotary_pos_emb 中的关键辅助函数
Args:
x: 输入张量,shape = (*, D)
Returns:
旋转后的张量,shape = (*, D)
"""
x1, x2 = x.chunk(2, dim=-1) # 分成前后两半:x1 是前半,x2 是后半
return torch.cat((-x2, x1), dim=-1) # 后半取负放前,形成 [-x2, x1]
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
将 RoPE 应用于 query 和 key 张量。
数学公式:
q_embed = q * cos(θ) + rotate_half(q) * sin(θ)
k_embed = k * cos(θ) + rotate_half(k) * sin(θ)
📚 理论来源:
- RoFormer: https://arxiv.org/abs/2104.09864
- LLaMA 使用此方法替代绝对位置嵌入
Args:
q: Query 张量,shape = (bsz, n_heads, seq_len, head_dim)
k: Key 张量,shape = (bsz, n_heads, seq_len, head_dim)
freqs_cis: 预计算的复数旋转因子,shape = (seq_len, head_dim)
Returns:
q_embed: 加入位置信息的 query
k_embed: 加入位置信息的 key
"""
# 提取实部(cos)和虚部(sin),并调整形状以便广播
cos = freqs_cis.real.view(1, 1, freqs_cis.size(0), -1) # (1,1,S,D)
sin = freqs_cis.imag.view(1, 1, freqs_cis.size(0), -1) # (1,1,S,D)
# 截断到当前序列长度(防止越界)
q_len = q.size(2)
k_len = k.size(2)
# 应用旋转公式
q_out = (q * cos[:, :, :q_len, :]) + (_rotate_half(q) * sin[:, :, :q_len, :])
k_out = (k * cos[:, :, :k_len, :]) + (_rotate_half(k) * sin[:, :, :k_len, :])
return q_out.type_as(q), k_out.type_as(k)
检查一下公式和计算
最新发布