def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
结论: freqs_cis是最终得到的预相乘角度信息
理解过程:
为方便描述,设定Q/K的大小为[512, 64], 即长度为512, 每行64个元素
第一行代码,得到 [1, ..., 0.0] 按一定比例分配的弧度角度,如下,dim缩写为d
为什么除以2? 因为对于一行长度为dim = 64 的Q/K向量来说,在rope编码中是按照复数形式组织的
每一对复数对应一个编码角度, 总共就是dim/2个角度值,如图所示
第二行代码, 这个是得到长度的绝对位置信息,如下
为什么是最大长度512的两倍?因为在rope变换中,需要对Qm和Kn都进行编码
第三行代码, 计算得到相乘的编码角度信息,如下
为什么是矩阵乘,因为M其实是个向量,theta也是一个向量
相乘之后的矩阵如下:
这样就预计算得到了所有组合的 m * theta