所有关于 assignment1 的代码已开源在:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
如果对你有帮助的话,记得顺手点个star喔!
本文记录了我在 Stanford CS336 第一次作业中完成的一项任务:从头实现一个 字节级 BPE(Byte Pair Encoding)分词器训练器,并支持高效的并行预分词与合并操作优化。本文将介绍 BPE 分词器的实现细节、并行预处理策略、特殊符号的处理方式,以及对 merge 步骤的优化。
一、作业要求概述
本次任务目标是在 TinyStories 数据集上训练一个字节级 BPE 分词器。如图所示,需要完成函数train_bpe。

核心要求如下:
- 初始化:以所有可能的字节(0-255)作为初始词汇表;
- 预处理:将文本分割为初始令牌(字节序列);
- 合并:统计所有字节对的频率,合并最频繁的对并更新词汇表,重复此过程直至达到目标词汇量。
二、整体实现结构
整个分词器训练流程主要分为以下三步:
1. 初始化词表
- 首先将所有 ASCII 字节值(0~255)加入初始词表。
- 然后添加用户指定的
special_tokens,如<|endoftext|>、<pad>等。
# 1. Vocabulary Initialization
vocab = {
i: bytes([i]) for i in range(256)}
for tok in special_tokens:
vocab[len(vocab)] = tok.encode("utf-8")
2. 并行预分词(Pre-tokenization)
(1)文档边界切分
为了并行处理,我们需要将文件划分为若干个 chunk,每个 chunk 的起始位置应当正好落在一个 <|endoftext|> 标记上,确保不会跨文档切分。我们使用 find_chunk_boundaries 函数来实现这一功能:
boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")
该函数会从初始估计位置开始,向前读取直到找到一个 <|endoftext|> 再作为真正边界。
(2)分发任务到子进程
我们利用 Python 的 multiprocessing.Pool,为每个 chunk 分发 process_chunk 任务:
task_args = [(input_path, start, end, special_tokens) for start, end in zip(boundaries[:-1], boundaries[1:])]
with Pool(processes=num_processes) as pool:
chunk_results = pool.map(process_chunk, task_args)
每个 chunk 的处理逻辑包括:
- 用正则表达式
re.split去除所有特殊 token,并按文档分别处理; - 使用 GPT-2 的正则表达式模式进行预分词;
- 每个 token 转为字节流,再按字节分割成
[b1, b2, ...]的形式。
3. BPE 合并训练
训练阶段按照以下流程进行:
- 初始化
counts和pair_to_indices,分别用于记录 pair 频率和其在哪些 token 中出现。 - 不断选择出现频率最高的 pair
(a, b)进行合并,生成新 token。 - 将受影响的所有 token 中的旧 pair 信息删除,并插入新 pair。
- 直到词表大小达到目标
vocab_size。
优化关键点在于:
- 只更新受到影响的 pair 计数,而不是每次全局重新统计;
- 避免在迭代过程中修改原集合,采用
.copy()确保稳定性。
示例代码片段(合并过程):
需要注意的是:在 affected_indices 用 pair_to_indices[max_pair].copy() 创建避免集合被动态修改。以及我一开始将 for j in affected_indices 写成三次遍历,导致pair_to_indices[max_pair]中间状态交叉污染,一次遍历就不会存在这样的问题(可以想想为什么)。
# 3. Compute BPE merges
merges : list[tuple[bytes, bytes]] = []
pre_tokens_bytes: list[list[bytes]] = [token for chunk in chunk_results for token in chunk]
counts = defaultdict(int)
pair_to_indices = defaultdict(set)
for idx, token in enumerate(pre_tokens_bytes):
for i in range(len(token) - 1):
pair = (token[i], token[i + 1])
counts[pair] += 1
pair_to_indices[pair].add(idx)
idx = len(vocab)
while idx < vocab_size:
if not counts:
break
max_pair: tuple[bytes, bytes] = None
max_cnt= -1
for pair, cnt in counts.items():
if cnt > max_cnt:
max_pair = pair
max_cnt = cnt
elif cnt == max_cnt:
if max_pair is None or pair > max_pair:
max_pair = pair
merges.append(max_pair)
a, b = max_pair
new_token = a + b
vocab[idx] = new_token
idx += 1
affected_indices = pair_to_indices[max_pair].copy()
for j in affected_indices:
token = pre_tokens_bytes[j]
for i in range(len(token) - 1):
old_pair = (token[i], token[i+1

最低0.47元/天 解锁文章
1290





