旋转位置编码(Rotary Position Embedding, RoPE)详解与代码实现
1. 旋转位置编码(RoPE)简介
1.1 位置编码的作用
在 Transformer 结构中,自注意力机制(Self-Attention)是无序的,也就是说,它不会像 RNN 那样天然地利用序列的顺序信息。因此,我们需要 位置编码(Positional Encoding) 来提供位置信息。
传统的 Transformer 位置编码有:
- 绝对位置编码(Absolute Positional Encoding):如
sin/cos
编码,编码是固定的,与输入无关。 - 可学习位置编码(Learnable Positional Encoding):位置编码参数是可训练的。
- 相对位置编码(Relative Position Encoding):考虑了两个 token 之间的相对位置关系。
- 旋转位置编码(RoPE):通过复数旋转的方式,将相对位置信息嵌入到注意力机制中。
1.2 RoPE 的原理
RoPE 通过复数数乘的方式,让 token 在高维空间中进行旋转,从而编码相对位置信息。
对于输入向量
x
=
(
x
1
,
x
2
,
.
.
.
,
x
d
)
x = (x_1, x_2, ..., x_d)
x=(x1,x2,...,xd)(其中
d
d
d 是维度),RoPE 将其拆分为偶数索引和奇数索引:
(
x
1
,
x
2
)
,
(
x
3
,
x
4
)
,
.
.
.
,
(
x
d
−
1
,
x
d
)
(x_1, x_2), (x_3, x_4), ..., (x_{d-1}, x_d)
(x1,x2),(x3,x4),...,(xd−1,xd)
然后,对这些二维向量对进行旋转:
(
x
′
,
y
′
)
=
(
x
cos
θ
−
y
sin
θ
,
x
sin
θ
+
y
cos
θ
)
(x', y') = (x \cos\theta - y \sin\theta, x \sin\theta + y \cos\theta)
(x′,y′)=(xcosθ−ysinθ,xsinθ+ycosθ)
其中
θ
\theta
θ 由位置
p
p
p 和固定基数
10000
10000
10000 计算得到:
θ
p
=
1
1000
0
2
i
/
d
\theta_p = \frac{1}{10000^{2i/d}}
θp=100002i/d1
其中
i
i
i 是当前维度索引,
d
d
d 是头部维度(head_dim)。
2. 代码解析
下面是一个极其简单的 RoPE 实现:
import torch
def rope(x):
batch_size, seq_len, head_dim = x.shape
device = x.device
# 计算 theta,生成不同维度的频率
theta = 10000 ** (-torch.arange(0, head_dim, 2, device=device) / head_dim)
# 生成位置索引
pos = torch.arange(seq_len, device=device).unsqueeze(1)
# 计算旋转角度
angles = pos * theta
# 计算 cos 和 sin
cos = torch.cos(angles)
sin = torch.sin(angles)
# 拆分输入数据,每两个维度为一组
x1, x2 = x[..., 0::2], x[..., 1::2]
# 应用旋转变换
rotated_x1 = x1 * cos - x2 * sin
rotated_x2 = x1 * sin + x2 * cos
# 拼接回去,形成完整的旋转后向量
return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
# 测试代码
batch_size, seq_len, head_dim = 2, 5, 4
x = torch.randn(batch_size, seq_len, head_dim)
x_rope = rope(x)
print("原始x:", x)
print("旋转后x:", x_rope)
3. 代码详解
3.1 生成 theta
theta = 10000 ** (-torch.arange(0, head_dim, 2, device=device) / head_dim)
torch.arange(0, head_dim, 2, device=device) / head_dim
生成一组频率索引。10000 ** (-...)
计算不同维度的旋转角频率,模拟不同层级的振荡。
3.2 计算 angles
pos = torch.arange(seq_len, device=device).unsqueeze(1)
angles = pos * theta
pos
是位置索引,如[0, 1, 2, 3, 4]
(如果seq_len=5
)。theta
乘以pos
,得到每个 token 在不同维度的旋转角度。
3.3 计算 cos
和 sin
cos = torch.cos(angles)
sin = torch.sin(angles)
- 计算每个位置
p
和维度i
上的旋转角cos(θ_p)
和sin(θ_p)
。
3.4 拆分输入向量
x1, x2 = x[..., 0::2], x[..., 1::2]
- 选取偶数索引
x1
和奇数索引x2
,两两成对形成二维向量对。
3.5 旋转变换
rotated_x1 = x1 * cos - x2 * sin
rotated_x2 = x1 * sin + x2 * cos
- 使用复数旋转公式,分别计算变换后的
x1'
和x2'
。
3.6 拼接回去
return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
- 先
stack
(堆叠)回来,变成[batch, seq_len, head_dim//2, 2]
。 - 再
flatten(-2)
展平,恢复原始head_dim
形状。
4. 结论
✅ RoPE 通过复数旋转方式编码相对位置信息,不需要额外参数。
✅ 代码高效简洁,只需要 cos
、sin
计算即可完成旋转变换。
✅ 适用于 Transformer,可以用于 query
和 key
,增强注意力机制对相对位置信息的建模。
💡 这个实现是最基础的 RoPE,如果用于实际 Transformer 结构,建议优化成带有缓存 cos/sin 的版本(如 Qwen2RotaryEmbedding),以提高推理速度。 🚀