Transformer中的相对位置编码(Relative Position Embedding)

本文介绍了Transformer中相对位置编码的概念,相对于绝对位置编码,它考虑了token间的相对距离,通过在self-attention计算中加入相对位置信息的表示,增强了模型对语序的捕获能力。文中详细解释了相对位置编码的公式推导,并提到在实际代码实现中,Hugging Face的库已支持该方法,但在大型模型中可能因引入额外参数而影响预训练权重的使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

回顾transformer中绝对位置编码(absolute position embedding)

在transformer的实现中,所有的input tokens是无序的,是没法像RNN的方法一样学到token之间的位置顺序关系,但是自然语言一定是有语序在里面的,所以原始的transformer的代码实现里就提供了一种非常简单的位置编码,输入的是一个和max_sequence_length一样长的固定index序列,比如在max_sequence_length=64的情况下,输入的位置信息序列就是[0,1,2,3…62,63],然后通过一个embedding层让网络自己去学习位置的编码和表征。因为不管输入的文本是什么,位置编码永远是一个定值的序列(在不同的网络中仅长度会变化),所以这就是一种绝对位置编码方式。我个人对这种位置编码方式是否能提供有用的位置信息是存疑的,因为当网络训练好了以后,所有的input embedding相当于都加上一个定值,然而不同的input里面的位置和语义信息都是不同的,我认为这种加定值的方式并不能学到一些位置带来的语义上的差别,但毕竟增加了一定的参数量,可能聊胜于无吧。
贴一下transformer中绝对位置编码的实现,在modeling_bert.py中,实现也很简单:

class BertEmbeddings(nn.Module):
   def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) #初始化定义position_embedding的embedding层,其实底层就是一个fc
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length] #生成定值的序列

        if position_ids.dtype is not torch.long:
            position_ids = position_ids.to(torch.long)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        embeddings = inputs_embeds
            embeddings += token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids) #输入的固定序列经过一个fc得到位置编码,直接point-wise加回到token embedding上
            embeddings += position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

相对位置编码方法详解和公式推导

相对位置编码的方法主要是出自于Google的一篇paper《Self-Attention with Relative Position Representations》
先说一下transformer中的self-attention机制:
每一层transformer由h个attention heads组成,其中每个attention head的输入为n个元素组成的token序列 x = ( x 1 , . . . , x n ) x = (x_1,...,x_n) x=(x1,...

### Transformer相对位置编码的改进方法 #### 背景介绍 Transformer 模型最初采用的是绝对位置编码 (Absolute Positional Encoding),即通过固定的位置索引来表示输入序列中的每个词。然而,在处理长序列时,绝对位置编码可能无法充分捕捉远距离上下文关系。因此,研究者提出了多种 **相对位置编码** 方法来解决这一问题。 #### 相对位置编码的基本概念 相对位置编码的核心思想在于利用序列中元素间的相对偏移量而非固定的绝对位置来进行建模[^1]。这种方法能够更好地捕获长程依赖关系,并减少因序列长度增加而导致的信息丢失风险。 #### 改进方法概述 以下是几种常见的相对位置编码改进方法及其对应的实现方式: --- #### 1. **基于注意力机制的相对位置编码** 该方法通过对自注意力机制进行扩展,将相对位置信息融入到注意力权重计算过程中。具体而言,对于任意两个 token \(i\) 和 \(j\),定义其相对位置为 \(\text{rel}_{ij} = j - i\)。然后将其映射至一个可学习的嵌入矩阵中。 ##### 实现细节 假设我们有一个大小为 \(d_{\text{model}}\) 的隐藏层维度,则可以构建如下形式的相对位置编码: \[ E_{\text{rel}}(k) = f(k), \quad k \in [-K, K], \] 其中 \(f(k)\) 表示某种函数(如正弦/余弦波形),\(K\) 是最大允许的距离范围。 在实际应用中,可以通过以下代码片段展示如何生成这些相对位置向量: ```python import torch import math def generate_relative_positions(max_len=512, d_model=512): positions = torch.arange(-max_len, max_len).float() / max_len * d_model embeddings = [] for pos in positions: embedding = [math.sin(pos / (10000**(2*i/d_model))) if i % 2 == 0 else \ math.cos(pos / (10000**(2*(i-1)/d_model))) for i in range(d_model)] embeddings.append(embedding) return torch.tensor(embeddings) relative_pos_embeddings = generate_relative_positions() print(relative_pos_embeddings.shape) # 输出形状应为 (2*max_len+1, d_model) ``` 此部分涉及的内容主要来源于早期关于 Transformer-XL 的工作。 --- #### 2. **旋转位置编码 (Rotary Position Embedding)** 另一种有效的改进方案称为“旋转位置编码”,它通过周期性的三角函数变换实现了动态调整的能力[^2]。相比于静态的绝对或相对位置编码,这种方式具有更强的表现力以及更少的记忆负担。 ##### 数学描述 给定一对 query (\(q_i\)) 及 key (\(k_j\)) 向量,我们可以分别对其进行分解并施加特定角度的旋转变换操作: \[ q'_i[k] = \begin{cases} q_i[k]\cos{\theta_k}, & \text{(even)} \\ -q_i[k]\sin{\theta_k}, & \text{(odd)} \end{cases}, \] 同理适用于 keys 的转换过程。这里 \(\theta_k=\frac{1}{T^{2k/d}}\) 定义了一个随频率变化的角度参数表。 上述逻辑可以用 PyTorch 编写成如下模块化结构: ```python class RotaryPositionEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() self.dim = dim inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) @staticmethod def rotate_half(x): x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:] return torch.cat((-x2, x1), dim=-1) def forward(self, seq_len): t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) # shape: [seq_len, dim] cos_emb, sin_emb = emb.cos(), emb.sin() return cos_emb[:, None, :], sin_emb[:, None, :] # Add batch dimension rot_pe = RotaryPositionEmbedding(dim=64) cosines, sines = rot_pe(seq_len=8) print(cosines.shape, sines.shape) # 应输出 [(8, 1, 64)] twice. ``` 这部分技术通常被用于视觉领域内的 ViTs 或 NLP 领域的大规模预训练模型当中。 --- #### 3. **乘积法 (Product Method)** 针对传统交叉法存在的不足之处——比如容易混淆不同方向上的相似间隔情况等问题,有学者提出了一种名为 “乘积法” 的新型策略[^4]。它的核心优势体现在两方面:一是显著降低了存储需求;二是增强了区分度较高的多维特征表达能力。 简单来说,就是把原本独立的一维坐标替换成了由多个因子共同决定的新体系。例如在一个二维平面上,假设有水平轴 L 和垂直轴 V ,那么对应某个点 P(i,j) 就能写作: \[ e_P[i][j]=e_L[i]*e_V[j]. \] 下面是一个简化版的例子演示如何创建这样的表格数据集: ```python def product_method_encoding(height=16, width=16, embed_dim=32): assert embed_dim % 2 == 0, 'Embedding size must be even.' half_d = embed_dim // 2 row_embs = torch.zeros([height, half_d]) col_embs = torch.zeros([width, half_d]) for h in range(height): row_embs[h] = [ math.sin(h/(10000**(r/half_d))) if r%2==0 else math.cos(h/(10000**((r-1)/half_d))) for r in range(half_d)] for w in range(width): col_embs[w] = [ math.sin(w/(10000**(c/half_d))) if c%2==0 else math.cos(w/(10000**((c-1)/half_d))) for c in range(half_d)] combined = torch.matmul(row_embs.unsqueeze(2), col_embs.T.unsqueeze(0)).view(height*width,-1) return combined encoded_matrix = product_method_encoding() print(encoded_matrix.size()) # Expected output: torch.Size([256, 32]) when height=width=16 and embed_dim=32 ``` 以上算法特别适合应用于图像分类或者目标检测任务之中。 --- ### 总结 综上所述,目前主流的三种相对位置编码改进方法分别为基于注意力机制的方法、旋转位置编码以及乘积法。每种都有各自的特点和适用场景,开发者可以根据具体的项目需求灵活选用合适的解决方案。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

今儿学习了没

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值