这段代码实现了一个简单的 GPT 模型,包括数据处理、模型构建、训练和文本生成等功能。以下是对代码的详细解释:
整体结构
代码主要分为以下几个部分:
- 导入必要的库:导入
tiktoken
用于文本编码,torch
用于深度学习计算,nn
用于构建神经网络模块,DataLoader
和Dataset
用于数据加载和处理。 - 数据集相关代码:定义了
GPTDatasetV1
类和create_dataloader_v1
函数,用于创建和加载数据集。 - 多头注意力机制:实现了
MultiHeadAttention
类,用于计算多头注意力。 - 归一化、激活函数和前馈网络:定义了
LayerNorm
、GELU
和FeedForward
类,分别用于层归一化、GELU 激活函数和前馈网络。 - Transformer 块和 GPT 模型:构建了
TransformerBlock
和GPTModel
类,用于组合模型的各个组件。 - 文本生成函数:实现了
generate_text_simple
函数,用于生成文本。 - 主函数:在
main
函数中,初始化模型、设置参数、进行文本生成,并输出结果。
代码详细解释
数据集相关
GPTDatasetV1
类:- 初始化函数:接收文本
txt
、分词器tokenizer
、最大长度max_length
和步长stride
作为参数。对输入文本进行分词,然后使用滑动窗口将文本分割成重叠的序列,每个序列长度为max_length
。 __len__
方法:返回数据集的样本数量。__getitem__
方法:根据索引返回输入序列和目标序列。
- 初始化函数:接收文本
create_dataloader_v1
函数:- 初始化
gpt2
分词器,创建GPTDatasetV1
数据集,然后使用DataLoader
加载数据集。
- 初始化
多头注意力机制
MultiHeadAttention
类:- 初始化函数:接收输入维度
d_in
、输出维度d_out
、上下文长度context_length
、丢弃率dropout
、头数num_heads
和偏置标志qkv_bias
作为参数。初始化线性层W_query
、W_key
、W_value
和out_proj
,以及掩码和丢弃层。 - 前向传播函数:计算查询、键和值矩阵,将其分割成多个头,计算注意力分数,应用掩码,计算注意力权重,最后计算上下文向量并进行投影。
- 初始化函数:接收输入维度
归一化、激活函数和前馈网络
LayerNorm
类:- 初始化函数:接收嵌入维度
emb_dim
作为参数,初始化缩放参数scale
和偏移参数shift
。 - 前向传播函数:计算输入的均值和方差,进行归一化,然后应用缩放和偏移。
- 初始化函数:接收嵌入维度
GELU
类:- 初始化函数:无参数。
- 前向传播函数:实现 GELU 激活函数的计算公式。
FeedForward
类:- 初始化函数:接收配置字典
cfg
作为参数,初始化包含两个线性层和一个 GELU 激活函数的前馈网络。 - 前向传播函数:将输入通过前馈网络。
- 初始化函数:接收配置字典
Transformer 块和 GPT 模型
TransformerBlock
类:- 初始化函数:接收配置字典
cfg
作为参数,初始化多头注意力层att
、前馈网络层ff
、两个层归一化层norm1
和norm2
,以及丢弃层drop_shortcut
。 - 前向传播函数:通过注意力层和前馈网络层,并添加残差连接。
- 初始化函数:接收配置字典
GPTModel
类:- 初始化函数:接收配置字典
cfg
作为参数,初始化词嵌入层tok_emb
、位置嵌入层pos_emb
、丢弃层drop_emb
、多个 Transformer 块trf_blocks
、最终的层归一化层final_norm
和输出层out_head
。 - 前向传播函数:将输入索引转换为嵌入,添加位置嵌入,通过 Transformer 块,进行归一化,最后通过输出层得到 logits。
- 初始化函数:接收配置字典
文本生成函数
generate_text_simple
函数:- 接收模型
model
、当前索引idx
、最大新生成令牌数max_new_tokens
和上下文大小context_size
作为参数。 - 在循环中,根据当前上下文生成预测,选择概率最高的索引,将其添加到当前索引序列中。
- 接收模型
主函数
main
函数:- 定义 GPT-124M 模型的配置参数。
- 初始化模型,设置随机种子,将模型设置为评估模式。
- 对输入文本进行编码,调用
generate_text_simple
函数生成文本,最后对生成的文本进行解码并输出。
运行代码
当脚本直接运行时(if __name__ == "__main__"
),main
函数会被执行,从而完成模型的初始化、文本生成和结果输出。
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.input_ids = []
self.target_ids = []
# Tokenize the entire text
token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids