Sharing a RoPE (rotary position embedding) implementation

from typing import Tuple
import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Helper function to reshape frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)
# Expand singleton dimensions so the frequency tensor broadcasts against (batch, seqlen, heads, dim)

def apply_rotary_emb(
        query: torch.Tensor,
        key: torch.Tensor,
        head_dim: int,
        max_seq_len: int,
        theta: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    Args:
        query (torch.Tensor): Query tensor to apply rotary embeddings. Shape: (batch_size, seqlen, n_local_heads, head_dim)
        key (torch.Tensor): Key tensor to apply rotary embeddings. Shape: (batch_size, seqlen, n_local_kv_heads, head_dim)
        head_dim (int): Dimension of each attention head.
        max_seq_len (int): Maximum sequence length supported by the model (not used directly in this implementation).
        theta (float): Base used to compute the rotary frequencies; defaults to 10000.0.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """

    _, seqlen, _, _ = query.shape  # query: (batch_size, seqlen, n_local_heads, head_dim)
    device = query.device  # build the frequency tensors on the same device as the inputs

    # reshape xq and xk to match the complex representation
    query_real, query_imag = query.float().reshape(query.shape[:-1] + (-1, 2)).unbind(-1)  # split query into even-index ("real") and odd-index ("imaginary") components
    key_real, key_imag = key.float().reshape(key.shape[:-1] + (-1, 2)).unbind(-1)  # split key the same way

    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2.0, device=device) / head_dim))
    pos_seq = torch.arange(0, seqlen, device=device, dtype=torch.float32)  # float positions so they combine with the float frequencies
    sinusoid_inp = torch.einsum("i,j->ij", pos_seq, inv_freq)  # outer product: (seqlen, head_dim // 2) table of rotation angles
    sin = torch.sin(sinusoid_inp)
    cos = torch.cos(sinusoid_inp)

    # Use the reshape_for_broadcast function to reshape cos and sin terms for broadcasting
    cos_rotations = reshape_for_broadcast(cos, query_real)  # reshape cos to (1, seqlen, 1, head_dim // 2) for broadcasting
    sin_rotations = reshape_for_broadcast(sin, query_imag)  # reshape sin the same way

    # Apply the rotations to the real and imaginary parts
    query_rot_real = cos_rotations * query_real - sin_rotations * query_imag  # rotate the even-index components of the query
    query_rot_imag = sin_rotations * query_real + cos_rotations * query_imag  # rotate the odd-index components of the query
    key_rot_real = cos_rotations * key_real - sin_rotations * key_imag  # rotate the even-index components of the key
    key_rot_imag = sin_rotations * key_real + cos_rotations * key_imag  # rotate the odd-index components of the key

    # Reassemble the even ("real") and odd ("imaginary") components back into the original
    # interleaved layout. A plain torch.cat along the last dim would instead put all even
    # components first and all odd components last, which is a different layout.
    query_out = torch.stack((query_rot_real, query_rot_imag), dim=-1).flatten(-2)  # (batch, seqlen, n_local_heads, head_dim)
    key_out = torch.stack((key_rot_real, key_rot_imag), dim=-1).flatten(-2)  # (batch, seqlen, n_local_kv_heads, head_dim)


    return query_out, key_out  # query and key with rotary position embeddings applied
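
A quick shape check for the function above, using made-up sizes (batch 2, sequence length 16, 4 query heads, 2 key heads, head_dim 64); any configuration with an even head_dim works the same way:

```python
q = torch.randn(2, 16, 4, 64)  # (batch, seqlen, n_local_heads, head_dim)
k = torch.randn(2, 16, 2, 64)  # (batch, seqlen, n_local_kv_heads, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, head_dim=64, max_seq_len=2048)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 16, 4, 64]) torch.Size([2, 16, 2, 64])
```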


The code above differs in two ways from the Bilibili walkthrough this post references, and from typical RoPE implementations:

1. Instead of multiplying the interleaved sequence q0, q1, q2, q3, ... element-wise with a tensor built by stacking two identical copies of the cos/sin tables, it first splits q into the even-index components q0, q2, q4, ... and the odd-index components q1, q3, q5, ..., and multiplies each half with the same cos/sin tables (see the equivalence sketch after this list).

2. The broadcast dimensions are no longer filled in by indexing a cached table, as in

cos_cached = idx_theta2.cos()[:, None, None, :]
sin_cached = idx_theta2.sin()[:, None, None, :]

but instead by

cos_rotations = reshape_for_broadcast(cos, query_real)  # reshape cos for broadcasting
sin_rotations = reshape_for_broadcast(sin, query_imag)  # reshape sin for broadcasting
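
As a sanity check on point 1, here is a minimal sketch (not from the original post; the names and sizes are illustrative) showing that the usual "duplicate cos/sin and multiply the interleaved q" formulation and the "split q into even/odd halves" formulation used above give the same result for a single head:

```python
import torch

seqlen, head_dim, theta = 8, 16, 10000.0
q = torch.randn(seqlen, head_dim)

inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2.0) / head_dim))
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, head_dim // 2)
cos, sin = angles.cos(), angles.sin()

# Formulation 1: repeat cos/sin so they line up with the interleaved q0, q1, q2, q3, ...
cos_full = torch.repeat_interleave(cos, 2, dim=-1)  # (seqlen, head_dim)
sin_full = torch.repeat_interleave(sin, 2, dim=-1)
q_pairs = q.reshape(seqlen, -1, 2)
q_rotated = torch.stack((-q_pairs[..., 1], q_pairs[..., 0]), dim=-1).reshape(seqlen, head_dim)
out1 = q * cos_full + q_rotated * sin_full

# Formulation 2 (as in the code above): split q into q0, q2, q4, ... and q1, q3, q5, ...
q_even, q_odd = q.reshape(seqlen, -1, 2).unbind(-1)
out2 = torch.stack((q_even * cos - q_odd * sin,
                    q_even * sin + q_odd * cos), dim=-1).flatten(-2)

print(torch.allclose(out1, out2, atol=1e-6))  # True
```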
