Transformer细节(九)——Transformer位置编码的改进

一、相对位置编码

        相对位置编码是针对绝对位置编码的一种改进,旨在捕捉序列中元素之间的相对位置信息。相对位置编码在处理长距离依赖关系和泛化到不同长度的序列时表现更好。

1、工作原理

        相对位置编码的核心思想是,位置关系是相对的而不是绝对的。相对位置编码将位置差异(相对位置)纳入注意力计算中。

        假设有一个序列长度为 \( N \),位置 \( i \) 和位置 \( j \) 的相对位置编码可以表示为 \( r_{ij} \),并用于调整注意力得分 \( e_{ij} \)。

2、实现方式

(1)相对位置嵌入

        使用一个相对位置嵌入矩阵 \( W_r \) 来表示位置差异。位置 \( i \) 和位置 \( j \) 的相对位置编码 \( r_{ij} \) 可以通过查询这个嵌入矩阵得到。

   \[
   r_{ij} = W_r[j - i + N - 1]
   \]

        其中 \( N \) 是序列的最大长度,确保索引 \( j - i + N - 1 \) 在有效范围内。

(2)相对位置编码与注意力计算

### Transformer相对位置编码的改进方法 #### 背景介绍 Transformer 模型最初采用的是绝对位置编码 (Absolute Positional Encoding),即通过固定的位置索引来表示输入序列中的每个词。然而,在处理长序列时,绝对位置编码可能无法充分捕捉远距离上下文关系。因此,研究者提出了多种 **相对位置编码** 方法来解决这一问题。 #### 相对位置编码的基本概念 相对位置编码的核心思想在于利用序列中元素间的相对偏移量而非固定的绝对位置来进行建模[^1]。这种方法能够更好地捕获长程依赖关系,并减少因序列长度增加而导致的信息丢失风险。 #### 改进方法概述 以下是几种常见的相对位置编码改进方法及其对应的实现方式: --- #### 1. **基于注意力机制的相对位置编码** 该方法通过对自注意力机制进行扩展,将相对位置信息融入到注意力权重计算过程中。具体而言,对于任意两个 token \(i\) 和 \(j\),定义其相对位置为 \(\text{rel}_{ij} = j - i\)。然后将其映射至一个可学习的嵌入矩阵中。 ##### 实现细节 假设我们有一个大小为 \(d_{\text{model}}\) 的隐藏层维度,则可以构建如下形式的相对位置编码: \[ E_{\text{rel}}(k) = f(k), \quad k \in [-K, K], \] 其中 \(f(k)\) 表示某种函数(如正弦/余弦波形),\(K\) 是最大允许的距离范围。 在实际应用中,可以通过以下代码片段展示如何生成这些相对位置向量: ```python import torch import math def generate_relative_positions(max_len=512, d_model=512): positions = torch.arange(-max_len, max_len).float() / max_len * d_model embeddings = [] for pos in positions: embedding = [math.sin(pos / (10000**(2*i/d_model))) if i % 2 == 0 else \ math.cos(pos / (10000**(2*(i-1)/d_model))) for i in range(d_model)] embeddings.append(embedding) return torch.tensor(embeddings) relative_pos_embeddings = generate_relative_positions() print(relative_pos_embeddings.shape) # 输出形状应为 (2*max_len+1, d_model) ``` 此部分涉及的内容主要来源于早期关于 Transformer-XL 的工作。 --- #### 2. **旋转位置编码 (Rotary Position Embedding)** 另一种有效的改进方案称为“旋转位置编码”,它通过周期性的三角函数变换实现了动态调整的能力[^2]。相比于静态的绝对或相对位置编码,这种方式具有更强的表现力以及更少的记忆负担。 ##### 数学描述 给定一对 query (\(q_i\)) 及 key (\(k_j\)) 向量,我们可以分别对其进行分解并施加特定角度的旋转变换操作: \[ q'_i[k] = \begin{cases} q_i[k]\cos{\theta_k}, & \text{(even)} \\ -q_i[k]\sin{\theta_k}, & \text{(odd)} \end{cases}, \] 同理适用于 keys 的转换过程。这里 \(\theta_k=\frac{1}{T^{2k/d}}\) 定义了一个随频率变化的角度参数表。 上述逻辑可以用 PyTorch 编写成如下模块化结构: ```python class RotaryPositionEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() self.dim = dim inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) @staticmethod def rotate_half(x): x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:] return torch.cat((-x2, x1), dim=-1) def forward(self, seq_len): t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) # shape: [seq_len, dim] cos_emb, sin_emb = emb.cos(), emb.sin() return cos_emb[:, None, :], sin_emb[:, None, :] # Add batch dimension rot_pe = RotaryPositionEmbedding(dim=64) cosines, sines = rot_pe(seq_len=8) print(cosines.shape, sines.shape) # 应输出 [(8, 1, 64)] twice. ``` 这部分技术通常被用于视觉领域内的 ViTs 或 NLP 领域的大规模预训练模型当中。 --- #### 3. **乘积法 (Product Method)** 针对传统交叉法存在的不足之处——比如容易混淆不同方向上的相似间隔情况等问题,有学者提出了一种名为 “乘积法” 的新型策略[^4]。它的核心优势体现在两方面:一是显著降低了存储需求;二是增强了区分度较高的多维特征表达能力。 简单来说,就是把原本独立的一维坐标替换成了由多个因子共同决定的新体系。例如在一个二维平面上,假设有水平轴 L 和垂直轴 V ,那么对应某个点 P(i,j) 就能写作: \[ e_P[i][j]=e_L[i]*e_V[j]. \] 下面是一个简化版的例子演示如何创建这样的表格数据集: ```python def product_method_encoding(height=16, width=16, embed_dim=32): assert embed_dim % 2 == 0, 'Embedding size must be even.' half_d = embed_dim // 2 row_embs = torch.zeros([height, half_d]) col_embs = torch.zeros([width, half_d]) for h in range(height): row_embs[h] = [ math.sin(h/(10000**(r/half_d))) if r%2==0 else math.cos(h/(10000**((r-1)/half_d))) for r in range(half_d)] for w in range(width): col_embs[w] = [ math.sin(w/(10000**(c/half_d))) if c%2==0 else math.cos(w/(10000**((c-1)/half_d))) for c in range(half_d)] combined = torch.matmul(row_embs.unsqueeze(2), col_embs.T.unsqueeze(0)).view(height*width,-1) return combined encoded_matrix = product_method_encoding() print(encoded_matrix.size()) # Expected output: torch.Size([256, 32]) when height=width=16 and embed_dim=32 ``` 以上算法特别适合应用于图像分类或者目标检测任务之中。 --- ### 总结 综上所述,目前主流的三种相对位置编码改进方法分别为基于注意力机制的方法、旋转位置编码以及乘积法。每种都有各自的特点和适用场景,开发者可以根据具体的项目需求灵活选用合适的解决方案。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值