LLama3 rope的pre_compute_freqs_cis函数理解

参考:Meta最新模型LLaMA细节与代码详解

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

 结论: freqs_cis是最终得到的预相乘角度信息

理解过程:

为方便描述,设定Q/K的大小为[512, 64], 即长度为512, 每行64个元素

max_{len} = 512, dim = 64

第一行代码,得到 [1, ..., 0.0] 按一定比例分配的弧度角度,如下,dim缩写为d

[\Theta_0, \Theta_1, \Theta_2, ..., \Theta_{d-1}]

为什么除以2? 因为对于一行长度为dim = 64 的Q/K向量来说,在rope编码中是按照复数形式组织的

[x_0, y_0, x_1, y_1, ..., x_{d/2-1}, y_{d/2-1}]

每一对复数对应一个编码角度\Theta_n, 总共就是dim/2个角度值,如图所示

第二行代码,  这个是得到长度的绝对位置信息,如下

\begin{bmatrix} m_0\\ m_1\\ ... \\ m_{512-1} \\n_0\\n_1\\...\\n_{512-1}\\ \end{bmatrix}

为什么是最大长度512的两倍?因为在rope变换中,需要对Qm和Kn都进行编码

第三行代码, 计算得到相乘的编码角度信息,如下

e^{jM\Theta}

 为什么是矩阵乘,因为M其实是个向量,theta也是一个向量

M = [m_0, m_1, m_2, ..., m_{512-1}]

\Theta = [\Theta_0, \Theta_1, ..., \Theta_{d/2-1}]

 相乘之后的矩阵如下:

\begin{bmatrix} \\ e^{m_0\Theta_0}, e^{m_0\Theta_1}, e^{m_0\Theta_2}, ..., e^{m_0\Theta_{d/2-1}} \\ e^{m_1\Theta_0}, e^{m_1\Theta_1}, e^{m_1\Theta_2}, ..., e^{m_1\Theta_{d/2-1}} \\ ... \\ e^{m_{1024-1}\Theta_0}, e^{m_{1024-1}\Theta_1},e^{ m_{1024-1}\Theta_2}, ..., e^{m_{1024-1}\Theta_{d/2-1}} \end{bmatrix}

这样就预计算得到了所有组合的 m * theta

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值