【手撕ROPE】极简方法实现旋转位置编码ROPE,面试遇到手撕ROPE再也不怕了!

旋转位置编码(Rotary Position Embedding, RoPE)详解与代码实现

1. 旋转位置编码(RoPE)简介

1.1 位置编码的作用

在 Transformer 结构中,自注意力机制(Self-Attention)是无序的,也就是说,它不会像 RNN 那样天然地利用序列的顺序信息。因此,我们需要 位置编码(Positional Encoding) 来提供位置信息。

传统的 Transformer 位置编码有:

  • 绝对位置编码(Absolute Positional Encoding):如 sin/cos 编码,编码是固定的,与输入无关。
  • 可学习位置编码(Learnable Positional Encoding):位置编码参数是可训练的。
  • 相对位置编码(Relative Position Encoding):考虑了两个 token 之间的相对位置关系。
  • 旋转位置编码(RoPE):通过复数旋转的方式,将相对位置信息嵌入到注意力机制中。

1.2 RoPE 的原理

RoPE 通过复数数乘的方式,让 token 在高维空间中进行旋转,从而编码相对位置信息。

对于输入向量 x = ( x 1 , x 2 , . . . , x d ) x = (x_1, x_2, ..., x_d) x=(x1,x2,...,xd)(其中 d d d 是维度),RoPE 将其拆分为偶数索引和奇数索引:
( x 1 , x 2 ) , ( x 3 , x 4 ) , . . . , ( x d − 1 , x d ) (x_1, x_2), (x_3, x_4), ..., (x_{d-1}, x_d) (x1,x2),(x3,x4),...,(xd1,xd)
然后,对这些二维向量对进行旋转:
( x ′ , y ′ ) = ( x cos ⁡ θ − y sin ⁡ θ , x sin ⁡ θ + y cos ⁡ θ ) (x', y') = (x \cos\theta - y \sin\theta, x \sin\theta + y \cos\theta) (x,y)=(xcosθysinθ,xsinθ+ycosθ)
其中 θ \theta θ 由位置 p p p 和固定基数 10000 10000 10000 计算得到:
θ p = 1 1000 0 2 i / d \theta_p = \frac{1}{10000^{2i/d}} θp=100002i/d1
其中 i i i 是当前维度索引, d d d 是头部维度(head_dim)。


2. 代码解析

下面是一个极其简单的 RoPE 实现:

import torch

def rope(x):
    batch_size, seq_len, head_dim = x.shape
    device = x.device

    # 计算 theta,生成不同维度的频率
    theta = 10000 ** (-torch.arange(0, head_dim, 2, device=device) / head_dim)
    # 生成位置索引
    pos = torch.arange(seq_len, device=device).unsqueeze(1)
    # 计算旋转角度
    angles = pos * theta

    # 计算 cos 和 sin
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    # 拆分输入数据,每两个维度为一组
    x1, x2 = x[..., 0::2], x[..., 1::2]
    
    # 应用旋转变换
    rotated_x1 = x1 * cos - x2 * sin
    rotated_x2 = x1 * sin + x2 * cos

    # 拼接回去,形成完整的旋转后向量
    return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)

# 测试代码
batch_size, seq_len, head_dim = 2, 5, 4
x = torch.randn(batch_size, seq_len, head_dim)
x_rope = rope(x)

print("原始x:", x)
print("旋转后x:", x_rope)

3. 代码详解

3.1 生成 theta

theta = 10000 ** (-torch.arange(0, head_dim, 2, device=device) / head_dim)
  • torch.arange(0, head_dim, 2, device=device) / head_dim 生成一组频率索引。
  • 10000 ** (-...) 计算不同维度的旋转角频率,模拟不同层级的振荡。

3.2 计算 angles

pos = torch.arange(seq_len, device=device).unsqueeze(1)
angles = pos * theta
  • pos 是位置索引,如 [0, 1, 2, 3, 4](如果 seq_len=5)。
  • theta 乘以 pos,得到每个 token 在不同维度的旋转角度。

3.3 计算 cossin

cos = torch.cos(angles)
sin = torch.sin(angles)
  • 计算每个位置 p 和维度 i 上的旋转角 cos(θ_p)sin(θ_p)

3.4 拆分输入向量

x1, x2 = x[..., 0::2], x[..., 1::2]
  • 选取偶数索引 x1 和奇数索引 x2,两两成对形成二维向量对。

3.5 旋转变换

rotated_x1 = x1 * cos - x2 * sin
rotated_x2 = x1 * sin + x2 * cos
  • 使用复数旋转公式,分别计算变换后的 x1'x2'

3.6 拼接回去

return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
  • stack(堆叠)回来,变成 [batch, seq_len, head_dim//2, 2]
  • flatten(-2) 展平,恢复原始 head_dim 形状。

4. 结论

RoPE 通过复数旋转方式编码相对位置信息,不需要额外参数。
代码高效简洁,只需要 cossin 计算即可完成旋转变换。
适用于 Transformer,可以用于 querykey,增强注意力机制对相对位置信息的建模。

💡 这个实现是最基础的 RoPE,如果用于实际 Transformer 结构,建议优化成带有缓存 cos/sin 的版本(如 Qwen2RotaryEmbedding),以提高推理速度。 🚀

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值