所有关于 assignment1 的代码已开源在:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
如果对你有帮助的话,记得顺手点个star喔!
作业要求见 https://github.com/stanford-cs336/assignment1-basics
作业1后半段要求从零实现一个基于Transformer的语言模型(LM)——这是理解现代大语言模型(LLM)内部机制的关键实践。
本文将详细拆解Transformer语言模型的完整实现过程,涵盖多头注意力、旋转位置编码(RoPE)、RMS归一化、SwiGLU前馈网络等核心组件,同时讲解自定义AdamW优化器、学习率调度、批量数据处理等训练辅助工具,帮助读者掌握从模型构建到训练落地的全流程。
一、项目概述
本次作业的核心目标是搭建一个仅含解码器的Transformer语言模型(类似GPT结构),使其具备预测序列中下一个token的能力。模型采用了当前主流的设计方案(如预归一化、RoPE、RMSNorm)以兼顾效率与性能,同时配套实现了完整的训练流水线,可直接在文本数据上进行优化。
模型与训练框架的核心特点:
- 基于解码器的Transformer结构,包含多头自注意力机制
- 采用旋转位置编码(RoPE),增强模型对序列位置信息的捕捉能力
- 使用RMSNorm替代传统LayerNorm,提升训练稳定性
- 前馈网络中引入SwiGLU激活函数,性能优于ReLU等传统激活
- 自定义AdamW优化器、余弦学习率调度器与梯度裁剪模块
- 支持训练 checkpoint 保存与加载,便于中断后续训
二、核心组件实现:transformer.py
transformer.py 文件包含了Transformer语言模型的所有核心模块,各组件设计遵循模块化原则,既便于调试,也为后续扩展预留了空间。
2.1 基础通用层
这类层是模型的"基础工具",在多个模块中被复用,负责实现最基本的张量变换操作。
2.1.1 无偏置线性层(Linear)
简化版的线性变换层,移除了偏置项(bias),并采用类Xavier初始化(截断正态分布)保证训练稳定性。前向传播通过einops库实现清晰的张量维度映射,避免手动reshape导致的维度混乱。
class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
super().__init__()
# 定义权重参数:形状为 (输出维度, 输入维度)
self.weight = nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype))
# 类Xavier初始化:标准差 = sqrt(2/(输入维度 + 输出维度)),避免梯度消失/爆炸
std = (2 / (in_features + out_features)) ** 0.5
nn.init.trunc_normal_(self.weight, std=std, a=-3*std, b=3*std)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 前向计算:y = x @ W^T(输入形状 ..., in_features → 输出形状 ..., out_features)
return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")
2.1.2 词嵌入层(Embedding)
将离散的token ID映射为连续的稠密向量,嵌入维度(embedding_dim)与模型隐藏层维度(d_model)保持一致。权重同样采用截断正态分布初始化,确保初始嵌入向量的分布合理性。
class Embedding(nn.Module):
def __init__(
self,
num_embeddings: int, # 词汇表大小(即总token数)
embedding_dim: int, # 嵌入向量维度(需等于d_model)
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.vocab_size = num_embeddings
self.d_model = embedding_dim
# 嵌入权重矩阵:形状为 (词汇表大小, 嵌入维度)
self.weight = nn.Parameter(torch.empty((self.vocab_size, self.d_model), device=device, dtype=dtype))
nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)
def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
# 输入:(batch_size, seq_len) → 输出:(batch_size, seq_len, embedding_dim)
return self.weight[token_ids] # 通过索引直接获取对应token的嵌入向量
2.1.3 RMS归一化层(RMSNorm)
相比传统LayerNorm,RMSNorm移除了均值中心化步骤,仅对输入的均方根(RMS)进行归一化,在减少计算量的同时提升训练稳定性,是LLaMA、GPT-4等模型的默认归一化方案。
class RMSNorm(nn.Module):
def __init__(
self,
d_model: int, # 输入维度(需等于模型隐藏层维度)
eps: float = 1e-5, # 防止分母为0的微小值
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.d_model = d_model
self.eps = eps
# 可学习的缩放参数:形状为 (d_model,),初始化为1(不改变归一化结果)
self.weight = nn.Parameter(torch.ones(self.d_model, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 输入:(batch_size, seq_len, d_model) → 输出:同输入形状
in_dtype = x.dtype # 保存输入数据类型,避免精度损失
x = x.to(dtype=torch.float32) # 转为float32计算,提升数值稳定性
# 1. 计算最后一维的均方根(RMS)
rms = (x.pow(2).mean(-1, keepdim=True) + self.eps) ** 0.5
# 2. 归一化 + 应用缩放参数
out = x / rms * self.weight
return out.to(dtype=in_dtype) # 恢复原数据类型
2.2 激活函数与前馈网络
前馈网络(FFN)是Transformer中负责"特征转换"的核心模块,而SwiGLU则是当前性能最优的激活函数之一,两者结合可显著提升模型的表达能力。
2.2.1 SwiGLU激活函数
SwiGLU是GLU(Gated Linear Unit)的变体,通过Sigmoid门控对线性变换结果进行筛选,相比ReLU能更好地捕捉特征间的非线性关系,同时避免梯度消失问题。
class SwiGLU(nn.Module):
def __init__(
self,
d_model: int, # 输入维度(模型隐藏层维度)
d_ff: int, # 前馈网络中间层维度(通常为d_model的4倍)
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff
# 定义三个线性层:W1/W3用于生成门控与候选特征,W2用于输出投影
self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)
# 辅助函数:Sigmoid线性单元(SiLU)
def _silu(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
# 辅助函数:门控线性单元(GLU)
def _glu(self, x: torch.Tensor) -> torch.Tensor:
return self._silu(self.w1(x)) * self.w3(x) # SiLU门控 × W3线性变换结果
def forward(self, x: torch.Tensor) -

最低0.47元/天 解锁文章
1503

被折叠的 条评论
为什么被折叠?



