大模型学习 (Datawhale_Happy-LLM)笔记5: 搭建一个 Transformer
搭建 Transformer 的核心组件总结
1. 基础功能模块
- 自注意力机制:通过
QKV 矩阵计算序列内依赖关系,公式为:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk
QKT)V
- 多头注意力:将特征拆分为多个子空间并行计算注意力,增强模型表达能力
- 位置编码:通过正弦余弦函数为序列添加位置信息,解决 Transformer 无序列感知问题
2. 核心网络层
- 层归一化 (LayerNorm):对每个样本的所有特征维度归一化,公式为:
LayerNorm(x)=α⊙x−μσ2+ϵ+β\text{LayerNorm}(x) = \alpha \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \betaLayerNorm(x)=α⊙σ2+ϵ
x−μ+β
- 前馈神经网络 (MLP):由
Linear+ReLU+Linear 构成,增强模型非线性表达能力
3. 编码器与解码器架构
- 编码器层:由
LayerNorm+MultiHeadAttention+残差连接 和 LayerNorm+MLP+残差连接 组成
- 解码器层:比编码器多一个
掩码多头注意力,用于避免预测时看到未来信息
- 堆叠结构:多个编码器/解码器层堆叠形成深度网络
4. 输入输出处理
- 嵌入层:将离散token转换为连续向量,与位置编码相加后输入网络
- 输出层:通过线性层将特征映射到词表空间,用于生成概率分布
5. 关键技术点
- 残差连接:解决深度网络训练梯度消失问题
- Dropout:随机丢弃神经元,防止过拟合
- 掩码机制:在解码器中屏蔽未来位置,确保预测时的因果关系
完整 Transformer 架构流程图
输入序列 ──→ 嵌入层 + 位置编码 ──→ 编码器序列 ──→ 编码器输出
│
↓
解码器序列(含掩码)
│
↓
线性层 + Softmax ─→ 输出序列
多头自注意力模块
class ModelArgs:
def __init__(self):
self.n_embed = 256
self.n_head = 4
self.head_dim = self.n_embed // self.n_heads
self.dropout = 0.1
self.max_seq_len = 512
self.n_layers = 6
self.vocab_size = None
self.block_size = None
class MultiHeadAttention(nn.Module):
def __init__(self, args: ModelArgs, is_causal=False):
super().__init__()
assert args.n_embed % args.n_head == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout