1. 什么是旋转位置编码?
旋转位置编码是一种将 位置信息 融入 词向量 的方法。它的核心思想是通过 旋转 的方式,将词向量的每一维根据其位置进行变换,从而让模型能够感知到词的位置信息。
2. 为什么要用旋转位置编码?
在 Transformer 模型中,词与词之间的 位置关系 非常重要。传统的 Transformer 使用 绝对位置编码(比如 Sinusoidal 位置编码),但它可能无法很好地捕捉 相对位置 信息。旋转位置编码通过 旋转 的方式,能够更好地建模词与词之间的相对位置关系。
3. 旋转位置编码的核心思想
旋转位置编码的核心思想是:通过旋转矩阵对词向量进行变换,旋转的角度与词的位置相关。
举个例子:
假设我们有一个二维的词向量 x = [ x 1 , x 2 ] x = [x_1, x_2] x=[x1,x2] ,我们可以通过旋转矩阵对它进行变换:
[ x 1 ′ x 2 ′ ] = [ cos θ − sin θ sin θ cos θ ] [ x 1 x 2 ] \begin{bmatrix} x'_1 \\ x'_2 \end{bmatrix} = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} [x1′x2′]=[cosθsinθ−sinθcosθ][x1x2]
其中, \theta 是与位置相关的角度。
通过这种方式,词向量的每一维都会根据位置进行旋转,从而融入位置信息。
4. 旋转位置编码的具体步骤
(1)将词向量分成两部分
假设词向量的维度是 d d d ,我们将它分成两部分,每部分的维度是 d / 2 d/2 d/2 :
x = [ x 1 , x 2 , … , x d / 2 , x d / 2 + 1 , … , x d ] x = [x_1, x_2, \dots, x_{d/2}, x_{d/2+1}, \dots, x_d] x=[x1,x2,…,xd/2,xd/2+1,…,xd]
前 d / 2 d/2 d/2 维是 x ( 1 ) x^{(1)} x(1) ,后 d / 2 d/2 d/2 维是 x ( 2 ) x^{(2)} x(2) 。
(2)构造旋转矩阵
对于位置 m m m ,我们构造一个旋转矩阵 R m R_m Rm 。这个矩阵的作用是对词向量的每一对维度进行旋转。
旋转矩阵的形式如下:
R m = [ cos m θ 1 − sin m θ 1 sin m θ 1 cos m θ 1 ] ⊕ [ cos m θ 2 − sin m θ 2 sin m θ 2 cos m θ 2 ] ⊕ ⋯ ⊕ [ cos m θ d / 2 − sin m θ d / 2 sin m θ d / 2 cos m θ d / 2 ] R_m = \begin{bmatrix} \cos m\theta_1 & -\sin m\theta_1 \\ \sin m\theta_1 & \cos m\theta_1 \end{bmatrix} \oplus \begin{bmatrix} \cos m\theta_2 & -\sin m\theta_2 \\ \sin m\theta_2 & \cos m\theta_2 \end{bmatrix} \oplus \dots \oplus \begin{bmatrix} \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{bmatrix} Rm=[cosmθ1sinmθ1−sinmθ1cosmθ1]⊕[cosmθ2sinmθ2−sinmθ2cosmθ2]⊕⋯⊕[cosmθd/2sinmθd/2−sinmθd/2cosmθd/2]
其中:
- θ i = 1000 0 − 2 i / d \theta_i = 10000^{-2i/d} θi=10000−2i/d 是频率因子。
- ⊕ \oplus ⊕ 表示矩阵的直和(即将多个小矩阵拼接成一个大矩阵)。
(3)应用旋转矩阵
将旋转矩阵 R m R_m Rm 应用到词向量 x x x 上:
x ′ = R m ⋅ x x' = R_m \cdot x x′=Rm⋅x
具体来说,对于每一对 ( x i , x d / 2 + i ) (x_i, x_{d/2+i}) (xi,xd/2+i) ,旋转操作如下:
[ x i ′ x d / 2 + i ′ ] = [ cos m θ i − sin m θ i sin m θ i cos m θ i ] [ x i x d / 2 + i ] \begin{bmatrix} x'_i \\ x'_{d/2+i} \end{bmatrix} = \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix} \begin{bmatrix} x_i \\ x_{d/2+i} \end{bmatrix} [xi′xd/2+i′]=[cosmθisinmθi−sinmθicosmθi][xixd/2+i]
5. 代码实现
以下是旋转位置编码的 PyTorch 实现:
import torch
import torch.nn as nn
class RotaryPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
# 初始化频率因子 theta_i
theta = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('theta', theta)
def forward(self, x, seq_len):
"""
x: (batch_size, seq_len, dim)
seq_len: 序列长度
"""
batch_size, seq_len, dim = x.shape
# 生成位置索引 m
m = torch.arange(seq_len, device=x.device).float()
# 构造旋转矩阵 R_m
freqs = torch.einsum('i,j->ij', m, self.theta) # (seq_len, dim/2)
freqs = torch.cat([freqs, freqs], dim=-1) # (seq_len, dim)
# 将 freqs 转换为 cos 和 sin
cos = torch.cos(freqs).unsqueeze(0) # (1, seq_len, dim)
sin = torch.sin(freqs).unsqueeze(0) # (1, seq_len, dim)
# 旋转操作
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
return x_rotated
6. 总结
- 旋转位置编码 通过旋转矩阵将位置信息融入词向量,能够更好地建模词与词之间的相对位置关系。
- 它的实现非常简单,只需要对词向量的每一对维度进行旋转即可。
- 这种方法在长序列任务中表现优异,且计算效率高。