一、概念
旋转位置编码(Rotary Position Embedding,RoPE)是一种用于大模型(如Transformer)的位置编码方法。它通过旋转embedding向量来引入位置信息,从而在模型中捕捉序列的相对位置信息。RoPE在一些自然语言处理任务中表现出色,尤其是在处理长序列时。
在Transformer模型中,位置编码(Position Embedding)用于引入序列中每个位置的位置信息,因为Transformer的自注意力机制本身是无序的。传统的绝对位置编码(如正弦和余弦位置编码)在处理长序列时可能会遇到一些问题,如无法有效捕捉相对位置信息。目前,诸如Llama、Qwen之类的大模型均在使用旋转位置嵌入方法,可见其重要性。
二、原理
旋转位置编码的核心思想是通过旋转embedding向量来引入位置信息。
1、旋转变换
假设embedding向量为,其维度为 d。RoPE将embedding向量中的特征分成两两一组,共d/2对,每组表示为(
)进行旋转变换。对于每个位置 i,定义一个角度
,通过下面的变换对每对二维向量进行旋转:
其中,是与位置 i 相关的旋转角度,通常是线性增长的。
2、旋转角度
旋转角度通常与位置 i 和embedding维度 d 相关,可以定义为:
这种定义方式类似于正弦和余弦位置编码,但通过旋转变换引入了相对位置信息。
三、python实现
import torch
import torch.nn as nn
import math
# 生成旋转矩阵
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
# 计算词向量元素两两分组之后,每组元素对应的旋转角度θᵢ
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
# 生成 token 序列索引 t = [0, 1,..., seq_len-1]
t = torch.arange(seq_len, device=freqs.device)
# freqs.shape = [seq_len, dim // 2]
freqs = torch.ger(t, freqs).float() # 计算m * θ
# 计算结果是个复数向量
# 假设 freqs = [x, y]
# 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
# 旋转位置编码计算
def apply_rotary_emb(xq, xk, freqs_cis):
# xq.shape = [batch_size, seq_len, dim]
# xq_.shape = [batch_size, seq_len, dim // 2, 2]
xq_ = xq.float().view(xq.size(0), xq.size(1), -1, 2)
xk_ = xk.float().view(xk.size(0), xk.size(1), -1, 2)
# 转为复数域
xq_ = torch.view_as_complex(xq_)
xk_ = torch.view_as_complex(xk_)
# 应用旋转操作,然后将结果转回实数域
# xq_out.shape = [batch_size, seq_len, dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).view(xq.size(0), xq.size(1), -1)
xk_out = torch.view_as_real(xk_ * freqs_cis).view(xk.size(0), xk.size(1), -1)
return xq_out.type_as(xq), xk_out.type_as(xk)
# 示例
batch_size, seq_len, dim = 32, 128, 512
xq = torch.randn(batch_size, seq_len, dim)
xk = torch.randn(batch_size, seq_len, dim)
freqs_cis = precompute_freqs_cis(dim, seq_len)
xq_rotated, xk_rotated = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rotated.size()) # 输出: torch.Size([32, 128, 512])
print(xk_rotated.size()) # 输出: torch.Size([32, 128, 512])