绝对位置编码
旋转位置编码
论文中有个很直观的图片展示了旋转变换的过程:
对于“我”对应的d维向量,拆分成d/2组以后,每组对应一个旋转角度;若第1组对应的向量为(x₁, x₂),应用旋转位置编码后,相当于这个二维分量旋转了 mθ₁ 的角度(其中 m 为该 token 的位置,θ₁ 为第1组的旋转频率)。
结合 transformers 4.42.4 版本,qwen2 源码分析如下:
1、定义和缓存cos、sin
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) as used in Qwen2 (transformers 4.42.x).

    Precomputes cos/sin lookup tables for every position up to
    ``max_position_embeddings`` at construction time (so ``torch.jit.trace``
    sees the buffers), and grows the cache on demand in ``forward``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: head dimension; rotations are applied to dim // 2 pairs.
            max_position_embeddings: initial size of the cos/sin cache.
            base: RoPE base; inv_freq[i] = 1 / base**(2i / dim).
            device: device on which to build inv_freq.
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base^(-2i/dim): per-pair rotation frequency, length dim // 2.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # freqs[m, i] = m * inv_freq[i]: rotation angle of pair i at position m,
        # shape (seq_len, dim // 2).
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables, each of shape (seq_len, dim), in x's dtype.

        Args:
            x: [bs, num_attention_heads, seq_len, head_size]; used for device/dtype
               (and for the default sequence length).
            seq_len: number of positions needed. Defaults to x's sequence length
               (x.shape[-2]) when omitted — the original code raised TypeError on
               the documented default of None.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cache lazily if the request exceeds what was precomputed.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
在上面这段代码中,inv_freq对应的是各分量的旋转频率(位置m处第i组的旋转角度为 m·inv_freq[i]),长度为d/2
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
这里的t为提前把所有可能的位置id 都先取好,并与对应的角度相乘,对应公式中的m,计算出来的矩阵freqs维度为(self.max_seq_len_cached, d/2)。这里outer函数计算如下:
torch.outer
import torch
t = torch.tensor([1,2,3])
inv_freq = torch.tensor([0.1,0.2,0.3])
f = torch.outer(t, inv_freq)
tensor([[0.1000, 0.2000, 0.3000],
[0.2000, 0.4000, 0.6000],
[0.3000, 0.6000, 0.9000]])
emb = torch.cat((f,f), dim=-1)
emb
tensor([[0.1000, 0.2000, 0.3000, 0.1000, 0.2000, 0.3000],
[0.2000, 0.4000, 0.6000, 0.200