大模型旋转位置编码RoPE

一、概念

        旋转位置编码(Rotary Position Embedding,RoPE)是一种用于大模型(如Transformer)的位置编码方法。它通过旋转embedding向量来引入位置信息,从而在模型中捕捉序列的相对位置信息。RoPE在一些自然语言处理任务中表现出色,尤其是在处理长序列时。

        在Transformer模型中,位置编码(Position Embedding)用于引入序列中每个位置的位置信息,因为Transformer的自注意力机制本身是无序的。传统的绝对位置编码(如正弦和余弦位置编码)在处理长序列时可能会遇到一些问题,如无法有效捕捉相对位置信息。目前,诸如Llama、Qwen之类的大模型均在使用旋转位置嵌入方法,可见其重要性。

二、原理

        旋转位置编码的核心思想是通过旋转embedding向量来引入位置信息。

1、旋转变换

        假设embedding向量为e_{i},其维度为 d。RoPE将embedding向量中的特征分成两两一组,共d/2对,每组表示为(x_{1},x_{2})进行旋转变换。对于每个位置 i,定义一个角度\theta_{i},通过下面的变换对每对二维向量进行旋转:

\binom{x_{1}^{'}}{x_{2}^{'}} \rightarrow \binom{cos(\theta_{i})\ \ \ \ -sin(\theta_{i})}{sin(\theta_{i})\ \ \ \ \ cos(\theta_{i})} \binom{x_{1}}{x_{2}}

        其中,\theta_{i}是与位置 i 相关的旋转角度,通常是线性增长的。

2、旋转角度

        旋转角度\theta_{i}通常与位置 i 和embedding维度 d 相关,可以定义为:

\theta_{i} = \frac{i}{10000^{2k/d}}

        这种定义方式类似于正弦和余弦位置编码,但通过旋转变换引入了相对位置信息。

三、python实现

import torch
import torch.nn as nn
import math

# 生成旋转矩阵
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度θᵢ
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2]
    freqs = torch.ger(t, freqs).float()  # 计算m * θ
    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(xq, xk, freqs_cis):
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().view(xq.size(0), xq.size(1), -1, 2)
    xk_ = xk.float().view(xk.size(0), xk.size(1), -1, 2)
    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).view(xq.size(0), xq.size(1), -1)
    xk_out = torch.view_as_real(xk_ * freqs_cis).view(xk.size(0), xk.size(1), -1)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# 示例
batch_size, seq_len, dim = 32, 128, 512
xq = torch.randn(batch_size, seq_len, dim)
xk = torch.randn(batch_size, seq_len, dim)
freqs_cis = precompute_freqs_cis(dim, seq_len)
xq_rotated, xk_rotated = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rotated.size())  # 输出: torch.Size([32, 128, 512])
print(xk_rotated.size())  # 输出: torch.Size([32, 128, 512])
以下是彩色图像的PSNR、SSIM、LPIPS和CIEDE2000评价算法的Matlab源码示例: 1. PSNR(峰值信噪比): ```matlab function psnr_value = PSNR(original, distorted) [M, N, ~] = size(original); mse = sum((original(:) - distorted(:)).^2) / (M * N * 3); max_value = max(original(:)); psnr_value = 10 * log10(max_value^2 / mse); end ``` 2. SSIM(结构相似性指数): ```matlab function ssim_value = SSIM(original, distorted) K1 = 0.01; K2 = 0.03; L = 255; C1 = (K1 * L)^2; C2 = (K2 * L)^2; original = double(original); distorted = double(distorted); mean_original = filter2(fspecial('gaussian', 11, 1.5), original, 'valid'); mean_distorted = filter2(fspecial('gaussian', 11, 1.5), distorted, 'valid'); var_original = filter2(fspecial('gaussian', 11, 1.5), original.^2, 'valid') - mean_original.^2; var_distorted = filter2(fspecial('gaussian', 11, 1.5), distorted.^2, 'valid') - mean_distorted.^2; cov_original_distorted = filter2(fspecial('gaussian', 11, 1.5), original .* distorted, 'valid') - mean_original .* mean_distorted; ssim_map = ((2 * mean_original .* mean_distorted + C1) .* (2 * cov_original_distorted + C2)) ./ ((mean_original.^2 + mean_distorted.^2 + C1) .* (var_original + var_distorted + C2)); ssim_value = mean2(ssim_map); end ``` 3. LPIPS(感知相似性指标):需要下载并使用LPIPS库,源码和使用说明可在https://github.com/richzhang/PerceptualSimilarity 找到。 4. CIEDE2000(CIE 2000色差公式):需要下载并使用CIEDE2000库,源码和使用说明可在https://www.mathworks.com/matlabcentral/fileexchange/46861-color-difference-cie-de2000 找到。 以上是基本的示例代码,用于评估图像质量的不同评价指标。你可以根据实际需求和图像数据进行适当的调整和修改。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值