【LLM】旋转位置编码 RoPE

1. 什么是旋转位置编码?

旋转位置编码是一种将 位置信息 融入 词向量 的方法。它的核心思想是通过 旋转 的方式,将词向量的每一维根据其位置进行变换,从而让模型能够感知到词的位置信息。


2. 为什么要用旋转位置编码?

在 Transformer 模型中,词与词之间的 位置关系 非常重要。传统的 Transformer 使用 绝对位置编码(比如 Sinusoidal 位置编码),但它可能无法很好地捕捉 相对位置 信息。旋转位置编码通过 旋转 的方式,能够更好地建模词与词之间的相对位置关系。


3. 旋转位置编码的核心思想

旋转位置编码的核心思想是:通过旋转矩阵对词向量进行变换,旋转的角度与词的位置相关

举个例子:

假设我们有一个二维的词向量 x = [ x 1 , x 2 ] x = [x_1, x_2] x=[x1,x2] ,我们可以通过旋转矩阵对它进行变换:

[ x 1 ′ x 2 ′ ] = [ cos ⁡ θ − sin ⁡ θ sin ⁡ θ cos ⁡ θ ] [ x 1 x 2 ] \begin{bmatrix} x'_1 \\ x'_2 \end{bmatrix} = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} [x1x2]=[cosθsinθsinθcosθ][x1x2]

其中, \theta 是与位置相关的角度。

通过这种方式,词向量的每一维都会根据位置进行旋转,从而融入位置信息。


4. 旋转位置编码的具体步骤

(1)将词向量分成两部分

假设词向量的维度是 d d d ,我们将它分成两部分,每部分的维度是 d / 2 d/2 d/2

x = [ x 1 , x 2 , … , x d / 2 , x d / 2 + 1 , … , x d ] x = [x_1, x_2, \dots, x_{d/2}, x_{d/2+1}, \dots, x_d] x=[x1,x2,,xd/2,xd/2+1,,xd]

d / 2 d/2 d/2 维是 x ( 1 ) x^{(1)} x(1) ,后 d / 2 d/2 d/2 维是 x ( 2 ) x^{(2)} x(2)

(2)构造旋转矩阵

对于位置 m m m ,我们构造一个旋转矩阵 R m R_m Rm 。这个矩阵的作用是对词向量的每一对维度进行旋转。

旋转矩阵的形式如下:

R m = [ cos ⁡ m θ 1 − sin ⁡ m θ 1 sin ⁡ m θ 1 cos ⁡ m θ 1 ] ⊕ [ cos ⁡ m θ 2 − sin ⁡ m θ 2 sin ⁡ m θ 2 cos ⁡ m θ 2 ] ⊕ ⋯ ⊕ [ cos ⁡ m θ d / 2 − sin ⁡ m θ d / 2 sin ⁡ m θ d / 2 cos ⁡ m θ d / 2 ] R_m = \begin{bmatrix} \cos m\theta_1 & -\sin m\theta_1 \\ \sin m\theta_1 & \cos m\theta_1 \end{bmatrix} \oplus \begin{bmatrix} \cos m\theta_2 & -\sin m\theta_2 \\ \sin m\theta_2 & \cos m\theta_2 \end{bmatrix} \oplus \dots \oplus \begin{bmatrix} \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{bmatrix} Rm=[cosmθ1sinmθ1sinmθ1cosmθ1][cosmθ2sinmθ2sinmθ2cosmθ2][cosmθd/2sinmθd/2sinmθd/2cosmθd/2]

其中:

  • θ i = 1000 0 − 2 i / d \theta_i = 10000^{-2i/d} θi=100002i/d 是频率因子。
  • ⊕ \oplus 表示矩阵的直和(即将多个小矩阵拼接成一个大矩阵)。
(3)应用旋转矩阵

将旋转矩阵 R m R_m Rm 应用到词向量 x x x 上:

x ′ = R m ⋅ x x' = R_m \cdot x x=Rmx

具体来说,对于每一对 ( x i , x d / 2 + i ) (x_i, x_{d/2+i}) (xi,xd/2+i) ,旋转操作如下:

[ x i ′ x d / 2 + i ′ ] = [ cos ⁡ m θ i − sin ⁡ m θ i sin ⁡ m θ i cos ⁡ m θ i ] [ x i x d / 2 + i ] \begin{bmatrix} x'_i \\ x'_{d/2+i} \end{bmatrix} = \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix} \begin{bmatrix} x_i \\ x_{d/2+i} \end{bmatrix} [xixd/2+i]=[cosmθisinmθisinmθicosmθi][xixd/2+i]


5. 代码实现

以下是旋转位置编码的 PyTorch 实现:

import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
        # 初始化频率因子 theta_i
        theta = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('theta', theta)
        
    def forward(self, x, seq_len):
        """
        x: (batch_size, seq_len, dim)
        seq_len: 序列长度
        """
        batch_size, seq_len, dim = x.shape
        
        # 生成位置索引 m
        m = torch.arange(seq_len, device=x.device).float()
        
        # 构造旋转矩阵 R_m
        freqs = torch.einsum('i,j->ij', m, self.theta)  # (seq_len, dim/2)
        freqs = torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)
        
        # 将 freqs 转换为 cos 和 sin
        cos = torch.cos(freqs).unsqueeze(0)  # (1, seq_len, dim)
        sin = torch.sin(freqs).unsqueeze(0)  # (1, seq_len, dim)
        
        # 旋转操作
        x1, x2 = x.chunk(2, dim=-1)
        x_rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        
        return x_rotated

6. 总结

  • 旋转位置编码 通过旋转矩阵将位置信息融入词向量,能够更好地建模词与词之间的相对位置关系。
  • 它的实现非常简单,只需要对词向量的每一对维度进行旋转即可。
  • 这种方法在长序列任务中表现优异,且计算效率高。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FOUR_A

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值