大模型旋转位置编码RoPE

一、概念

        旋转位置编码(Rotary Position Embedding,RoPE)是一种用于大模型(如Transformer)的位置编码方法。它通过旋转embedding向量来引入位置信息,从而在模型中捕捉序列的相对位置信息。RoPE在一些自然语言处理任务中表现出色,尤其是在处理长序列时。

        在Transformer模型中,位置编码(Position Embedding)用于引入序列中每个位置的位置信息,因为Transformer的自注意力机制本身是无序的。传统的绝对位置编码(如正弦和余弦位置编码)在处理长序列时可能会遇到一些问题,如无法有效捕捉相对位置信息。目前,诸如Llama、Qwen之类的大模型均在使用旋转位置嵌入方法,可见其重要性。

二、原理

        旋转位置编码的核心思想是通过旋转embedding向量来引入位置信息。

1、旋转变换

        假设embedding向量为e_{i},其维度为 d。RoPE将embedding向量中的特征分成两两一组,共d/2对,每组表示为(x_{1},x_{2})进行旋转变换。对于每个位置 i,定义一个角度\theta_{i},通过下面的变换对每对二维向量进行旋转:

\binom{x_{1}^{'}}{x_{2}^{'}} \rightarrow \binom{cos(\theta_{i})\ \ \ \ -sin(\theta_{i})}{sin(\theta_{i})\ \ \ \ \ cos(\theta_{i})} \binom{x_{1}}{x_{2}}

        其中,\theta_{i}是与位置 i 相关的旋转角度,通常是线性增长的。

2、旋转角度

        旋转角度\theta_{i}通常与位置 i 和embedding维度 d 相关,可以定义为:

\theta_{i} = \frac{i}{10000^{2k/d}}

        这种定义方式类似于正弦和余弦位置编码,但通过旋转变换引入了相对位置信息。

三、python实现

import torch
import torch.nn as nn
import math

# 生成旋转矩阵
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度θᵢ
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2]
    freqs = torch.ger(t, freqs).float()  # 计算m * θ
    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(xq, xk, freqs_cis):
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().view(xq.size(0), xq.size(1), -1, 2)
    xk_ = xk.float().view(xk.size(0), xk.size(1), -1, 2)
    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).view(xq.size(0), xq.size(1), -1)
    xk_out = torch.view_as_real(xk_ * freqs_cis).view(xk.size(0), xk.size(1), -1)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# 示例
batch_size, seq_len, dim = 32, 128, 512
xq = torch.randn(batch_size, seq_len, dim)
xk = torch.randn(batch_size, seq_len, dim)
freqs_cis = precompute_freqs_cis(dim, seq_len)
xq_rotated, xk_rotated = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rotated.size())  # 输出: torch.Size([32, 128, 512])
print(xk_rotated.size())  # 输出: torch.Size([32, 128, 512])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值