
本文将带你从零构建类GPT模型:通过实现层归一化、前馈网络和Transformer块等核心组件,打造一个完整的文本生成模型架构,为后续训练奠定基础。
目录
一、GPT模型架构全景图
1.1 模型组件分解

1.2 GPT-2模型规格
| 模型版本 | 参数量 | 层数 | 头数 | 隐藏维度 | 上下文长度 |
|---|---|---|---|---|---|
| GPT-2 Small | 124M | 12 | 12 | 768 | 1024 |
| GPT-2 Medium | 355M | 24 | 16 | 1024 | 1024 |
| GPT-2 Large | 774M | 36 | 20 | 1280 | 1024 |
| GPT-2 XL | 1.5B | 48 | 25 | 1600 | 1024 |
二、层归一化实现
2.1 为什么需要层归一化?
# 未归一化的激活值问题
activations = torch.randn(1000, 768) * 10 # 模拟大方差激活
mean = activations.mean(dim=1) # 各样本均值差异大
std = activations.std(dim=1) # 各样本标准差差异大
print("均值范围:", mean.min().item(), "~", mean.max().item())
print("标准差范围:", std.min().item(), "~", std.max().item())
层归一化优势:
-
稳定训练过程
-
加速收敛速度
-
缓解梯度消失/爆炸问题
-
减少对初始化的依赖
2.2 层归一化实现代码
class LayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(d_model)) # 缩放参数
self.beta = nn.Parameter(torch.zeros(d_model)) # 平移参数
def forward(self, x):
# 计算均值和方差
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True, unbiased=False)
# 归一化
x_normalized = (x - mean) / (std + self.eps)
# 缩放和平移
return self.gamma * x_normalized + self.beta
# 与PyTorch官方实现对比测试
def test_layernorm():
input_tensor = torch.randn(2, 3, 768)
# 自定义层归一化
custom_ln = LayerNorm(768)
custom_out = custom_ln(input_tensor)
# PyTorch官方层归一化
official_ln = nn.LayerNorm(768)
official_out = official_ln(input_tensor)
# 检查差异
diff = (custom_out - official_out).abs().max().item()
print(f"最大差异: {diff:.6f}") # 应小于1e-5
test_layernorm()
三、前馈神经网络实现
3.1 GPT中的前馈结构
3.2 GELU激活函数
class GELU(nn.Module):
"""高斯误差线性单元激活函数"""
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
3.3 完整前馈网络实现
class FeedForward(nn.Module):
def __init__(self, d_model, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, 4 * d_model), # 扩展维度
GELU(), # 使用自定义GELU
nn.Linear(4 * d_model, d_model), # 降回原维度
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

最低0.47元/天 解锁文章
:从头实现GPT模型——构建文本生成引擎&spm=1001.2101.3001.5002&articleId=148528461&d=1&t=3&u=a51bc67c08ea492f9c3de98bc0e2cc2d)
1941

被折叠的 条评论
为什么被折叠?



