Stanford CS336 | Assignment 1 - Transformer Language Model Architecture

所有关于 assignment1 的代码已开源在:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
如果对你有帮助的话,记得顺手点个star喔!

作业要求见 https://github.com/stanford-cs336/assignment1-basics

作业1后半段要求从零实现一个基于Transformer的语言模型(LM)——这是理解现代大语言模型(LLM)内部机制的关键实践。

本文将详细拆解Transformer语言模型的完整实现过程,涵盖多头注意力、旋转位置编码(RoPE)、RMS归一化、SwiGLU前馈网络等核心组件,同时讲解自定义AdamW优化器、学习率调度、批量数据处理等训练辅助工具,帮助读者掌握从模型构建到训练落地的全流程。

一、项目概述

本次作业的核心目标是搭建一个仅含解码器的Transformer语言模型(类似GPT结构),使其具备预测序列中下一个token的能力。模型采用了当前主流的设计方案(如预归一化、RoPE、RMSNorm)以兼顾效率与性能,同时配套实现了完整的训练流水线,可直接在文本数据上进行优化。

模型与训练框架的核心特点:

  • 基于解码器的Transformer结构,包含多头自注意力机制
  • 采用旋转位置编码(RoPE),增强模型对序列位置信息的捕捉能力
  • 使用RMSNorm替代传统LayerNorm,提升训练稳定性
  • 前馈网络中引入SwiGLU激活函数,性能优于ReLU等传统激活
  • 自定义AdamW优化器、余弦学习率调度器与梯度裁剪模块
  • 支持训练 checkpoint 保存与加载,便于中断后续训

二、核心组件实现:transformer.py

transformer.py 文件包含了Transformer语言模型的所有核心模块,各组件设计遵循模块化原则,既便于调试,也为后续扩展预留了空间。

2.1 基础通用层

这类层是模型的"基础工具",在多个模块中被复用,负责实现最基本的张量变换操作。

2.1.1 无偏置线性层(Linear)

简化版的线性变换层,移除了偏置项(bias),并采用类Xavier初始化(截断正态分布)保证训练稳定性。前向传播通过einops库实现清晰的张量维度映射,避免手动reshape导致的维度混乱。

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        super().__init__()
        # 定义权重参数:形状为 (输出维度, 输入维度)
        self.weight = nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype))
        # 类Xavier初始化:标准差 = sqrt(2/(输入维度 + 输出维度)),避免梯度消失/爆炸
        std = (2 / (in_features + out_features)) ** 0.5
        nn.init.trunc_normal_(self.weight, std=std, a=-3*std, b=3*std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 前向计算:y = x @ W^T(输入形状 ..., in_features → 输出形状 ..., out_features)
        return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")
2.1.2 词嵌入层(Embedding)

将离散的token ID映射为连续的稠密向量,嵌入维度(embedding_dim)与模型隐藏层维度(d_model)保持一致。权重同样采用截断正态分布初始化,确保初始嵌入向量的分布合理性。

class Embedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,  # 词汇表大小(即总token数)
        embedding_dim: int,   # 嵌入向量维度(需等于d_model)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.vocab_size = num_embeddings
        self.d_model = embedding_dim
        # 嵌入权重矩阵:形状为 (词汇表大小, 嵌入维度)
        self.weight = nn.Parameter(torch.empty((self.vocab_size, self.d_model), device=device, dtype=dtype))
        nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        # 输入:(batch_size, seq_len) → 输出:(batch_size, seq_len, embedding_dim)
        return self.weight[token_ids]  # 通过索引直接获取对应token的嵌入向量
2.1.3 RMS归一化层(RMSNorm)

相比传统LayerNorm,RMSNorm移除了均值中心化步骤,仅对输入的均方根(RMS)进行归一化,在减少计算量的同时提升训练稳定性,是LLaMA、GPT-4等模型的默认归一化方案。

class RMSNorm(nn.Module):
    def __init__(
        self,
        d_model: int,          # 输入维度(需等于模型隐藏层维度)
        eps: float = 1e-5,     # 防止分母为0的微小值
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # 可学习的缩放参数:形状为 (d_model,),初始化为1(不改变归一化结果)
        self.weight = nn.Parameter(torch.ones(self.d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 输入:(batch_size, seq_len, d_model) → 输出:同输入形状
        in_dtype = x.dtype  # 保存输入数据类型,避免精度损失
        x = x.to(dtype=torch.float32)  # 转为float32计算,提升数值稳定性
        
        # 1. 计算最后一维的均方根(RMS)
        rms = (x.pow(2).mean(-1, keepdim=True) + self.eps) ** 0.5
        # 2. 归一化 + 应用缩放参数
        out = x / rms * self.weight
        
        return out.to(dtype=in_dtype)  # 恢复原数据类型

2.2 激活函数与前馈网络

前馈网络(FFN)是Transformer中负责"特征转换"的核心模块,而SwiGLU则是当前性能最优的激活函数之一,两者结合可显著提升模型的表达能力。

2.2.1 SwiGLU激活函数

SwiGLU是GLU(Gated Linear Unit)的变体,通过Sigmoid门控对线性变换结果进行筛选,相比ReLU能更好地捕捉特征间的非线性关系,同时避免梯度消失问题。

class SwiGLU(nn.Module):
    def __init__(
        self,
        d_model: int,    # 输入维度(模型隐藏层维度)
        d_ff: int,       # 前馈网络中间层维度(通常为d_model的4倍)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # 定义三个线性层:W1/W3用于生成门控与候选特征,W2用于输出投影
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    # 辅助函数:Sigmoid线性单元(SiLU)
    def _silu(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    # 辅助函数:门控线性单元(GLU)
    def _glu(self, x: torch.Tensor) -> torch.Tensor:
        return self._silu(self.w1(x)) * self.w3(x)  # SiLU门控 × W3线性变换结果

    def forward(self, x: torch.Tensor) -
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值