Transformer 模型最通俗易懂的实现(基于快递系统的比喻)
在这个代码中,我们将 Transformer 模型比作一个快递系统:
- 位置编码(Positional Encoding):就像为每个快递包裹贴上顺序号标签,告诉系统包裹的先后顺序。
- 多头注意力机制(Multi-Head Attention):就像多个快递员同时工作,每个快递员各自检查并匹配包裹信息,关注不同的细节(例如包裹外观、重量、地址等)。
- 前馈神经网络(FeedForward):相当于包裹在配送中心经过深入加工、质量检测、重新包装等流程,使信息更加完善,但包裹的总大小保持不变。
- 编码器层(EncoderLayer):将整个分拣中心看作一个层,负责接受客户寄件的包裹,根据包裹之间的信息关系进行分类、整合。
- 解码器层(DecoderLayer):就像配送中心的一层,用于把经过编码整理后的包裹按照物流配送要求分发出去。
- Transformer:整合上述所有环节,构成一个完整的快递运输系统。
下面的代码展示了各个部分的详细实现。
import torch
import torch.nn as nn
import math
# 1. 位置编码:为每个快递包裹贴上序号标签
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length=5000):
"""
参数:
d_model: 每个包裹的特征维度
max_seq_length: 包裹的最大数量
功能:
为每个包裹添加位置编码,确保系统知道包裹的顺序
"""
super().__init__()
# 创建一个序号:0, 1, 2, ... 表示包裹顺序
position = torch.arange(max_seq_length).unsqueeze(1)
# 计算用于缩放的位置因子(分别对偶数、奇数维度采用 sin 和 cos)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
# 构建位置编码矩阵,形状为 [max_seq_length, d_model]
pe = torch.zeros(max_seq_length, d_model)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度使用 sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度使用 cos
# 注册为缓冲区(不参与训练更新)
self.register_buffer('pe', pe)
def forward(self, x):
"""
参数:
x: 输入包裹特征,形状为 [batch_size, seq_length, d_model]
返回:
将位置编码加到输入上的结果,相当于为每个包裹添加了序号标签
"""
return x + self.pe[:x.size(1)]
# 2. 多头注意力机制:多个快递员同时工作、关注不同特征
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
"""
参数:
d_model: 每个包裹的特征维度
num_heads: 快递员数量(注意力头数量)
要求:
d_model 必须能被 num_heads 整除,保证每个快递员获得相同维度的包裹信息
"""
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# 定义三个线性层,分别生成查询(Q)、键(K)和值(V),相当于为快递员分发不同工具
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
# 一个输出层,用于整合所有快递员的处理结果,恢复成整个包裹的特征
self.output_linear = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
"""
参数:
q, k, v: 输入包裹特征(通常 q=k=v),形状为 [batch_size, seq_length, d_model]
mask: (可选)用于屏蔽不需要关注的包裹信息
功能:
模拟多个快递员同时根据不同角度处理包裹信息,输出经过共同整合后的特征
"""
batch_size = q.size(0)
# 将包裹特征通过线性变换后,拆分成多头,每个快递员处理分到的部分特征,形状调整为 [batch_size, num_heads, seq_length, head_dim]
q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力分数,表示每个快递员对不同包裹之间相似度的评分:scores = Q * K^T / sqrt(head_dim)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 如果设置了mask,则将需要屏蔽的信息得分降为极小值
if mask is not None:
scores = scores