【精选优质专栏推荐】
- 《AI 技术前沿》 —— 紧跟 AI 最新趋势与应用
- 《网络安全新手快速入门(附漏洞挖掘案例)》 —— 零基础安全入门必看
- 《BurpSuite 入门教程(附实战图文)》 —— 渗透测试必备工具详解
- 《网安渗透工具使用教程(全)》 —— 一站式工具手册
- 《CTF 新手入门实战教程》 —— 从题目讲解到实战技巧
- 《前后端项目开发(新手必知必会)》 —— 实战驱动快速上手
每个专栏均配有案例与图文讲解,循序渐进,适合新手与进阶学习者,欢迎订阅。

前言
BERT 是一种仅编码器(encoder-only)的 Transformer 模型,在针对各种 NLP 任务进行微调之前,会先在掩码语言模型(MLM)和下一句预测(NSP)任务上进行预训练。预训练需要特殊的数据准备方式。在本文中,你将学习如何:
-
创建掩码语言模型(MLM)训练数据
-
创建下一句预测(NSP)训练数据
-
为 BERT 训练设置标签
-
使用 Hugging Face datasets 来存储训练数据
本文分为四个部分,分别是:
-
准备文档
-
从文档中创建句子对
-
掩码 Token
-
保存训练数据以便复用
准备文档
与仅解码器模型不同,BERT 的预训练更加复杂。如上一篇文章所述,预训练会同时优化 MLM 和 NSP 任务的组合损失。因此,训练数据必须同时为这两个任务打上标签。
我们按照 Google 的 BERT 实现方式,使用 Wikitext-2 或 Wikitext-103 数据集。数据集中的每一行要么是空行、要么是以 “=” 开头的标题行、要么是普通文本。只有普通文本行会被用于训练。
BERT 的训练样本需要包含两个“句子”。为简单起见,定义如下:
“句子”是数据集中的一行文本
文档是由连续的“句子”组成的序列,不同文档之间由空行或标题行分隔
假设你已经像上一篇文章中那样训练好了一个 tokenizer,下面我们创建一个函数,将文本收集为文档列表:
import tokenizers
from datasets import load_dataset
def create_docs(path, name, tokenizer):
"""Load wikitext dataset and extract text as documents"""
dataset = load_dataset(path, name, split="train")
docs = []
for line in dataset["text"]:
line = line.strip()
if not line or line.startswith("="):
docs.append([]) # new document encountered
else:
tokens = tokenizer.encode(line).ids
docs[-1].append(tokens)
docs = [doc for doc in docs if doc] # remove empty documents
return docs
# load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("wikitext-103_wordpiece.json")
docs = create_docs("wikitext", "wikitext-103-raw-v1", tokenizer)
这段代码会按顺序处理文本行。当遇到文档分隔符时,就为后续的行创建一个新的列表。每一行都会以 tokenizer 输出的整数列表形式进行存储。
跟踪文档对于 NSP 任务至关重要:只有当两个句子都来自同一个文档时,它们才能构成一个“下一句”序列对。
从文档中创建句子对
下一步是从文档中抽取句子对。句子对可以是同一文档中的相邻句子,也可以是来自不同文档的随机句子。我们使用如下算法来创建句子对:
遍历每个文档中的每一个句子,将其作为第一个句子
对于第二个句子,要么选择同一文档中的下一句,要么从另一个文档中随机选择一句
但这里有一个约束条件:句子对的总长度不能超过 BERT 的最大序列长度。如有必要,需要对句子进行截断。
下面展示了如何用 Python 实现该算法:
...
import random
# copy the document
chunks = []
for chunk in all_docs[doc_idx]:
chunks.append(chunk)
# exhaust chunks and create samples
while chunks:
# scan until target token length
running_length = 0
end = 1
while end < len(chunks) and running_length < target_length:
running_length += len(chunks[end-1])
end += 1
# randomly separate the chunk into two segments
sep = random.randint(1, end-1) if end > 1 else 1
sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
# sentence B: may be from another document
if not sentence_b or random.random() < 0.5:
# find another document (must not be the same as doc_idx)
b_idx = random.randint(0, len(all_docs)-2)
if b_idx >= doc_idx:
b_idx += 1
# sentence B starts from a random position in the new document
sentence_b = []
running_length = len(sentence_a)
i = random.randint(0, len(all_docs[b_idx])-1)
while i < len(all_docs[b_idx]) and running_length < target_length:
sentence_b.extend(all_docs[b_idx][i])
running_length += len(all_docs[b_idx][i])
i += 1
is_random_next = True
chunks = chunks[sep:]
else:
is_random_next = False
chunks = chunks[end:]
# the pair is found
pair = (sentence_a, sentence_b)
这段代码会针对索引为 doc_idx 的某个文档创建句子对。最开始的 for 循环会将句子复制为 chunks,以避免修改原始文档。while 循环会扫描 chunks,直到达到目标 token 长度,然后将其随机拆分为两个片段。
在 50% 的概率下,第二个句子会被替换为来自另一个文档的随机句子。这个较大的 if 代码块负责创建 NSP 任务的标签(记录在 is_random_next 中),并从另一个文档中采样一个句子。
在每一次迭代结束时,chunks 会被更新,仅保留尚未使用的部分。当该列表为空时,说明文档已被完全处理。sentence_a 和 sentence_b 都是由整数形式的 token ID 组成的列表。
这种方法遵循了 Google 原始的 BERT 实现方式,尽管它并不会穷举所有可能的组合。上述创建的句子对可能会超过目标序列长度,因此需要进行截断。截断的实现如下所示:
def truncate_seq_pair(sentence_a, sentence_b, max_num_tokens):
while len(sentence_a) + len(sentence_b) > max_num_tokens:
# pick the longer sentence to remove tokens from
candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
# remove one token from either end in equal probabilities
if random.random() < 0.5:
candidate.pop(0)
else:
candidate.pop()
截断操作会反复应用在较长的句子上,直到总长度小于目标值为止。每次以相同的概率从句子的首部或尾部移除一个 token。最终结果可能是原始句子中间的一段片段,这也正是代码中使用 “chunk” 这一命名方式的原因。
掩码 Token
掩码 token 是 BERT 训练数据中最关键的部分。原始论文中规定有 15% 的 token 会被掩码。实际上,这意味着模型只在其输出中的 15% 位置上被训练去预测 token。在这 15% 中,token 可能会是以下三种情况之一:
-
80% 的情况下,token 会被替换为 [MASK] token。
-
10% 的情况下,token 会被替换为词表中的一个随机 token。
-
10% 的情况下,token 保持不变。
在所有情况下,模型都必须正确预测原始 token。在创建好句子对之后,我们可以按如下方式实现掩码过程:
def create_sample(sentence_a, sentence_b, is_random_next, tokenizer,
max_seq_length=512, mask_prob=0.15, max_predictions_per_seq=20):
# Collect id of special tokens
cls_id = tokenizer.token_to_id("[CLS]")
sep_id = tokenizer.token_to_id("[SEP]")
mask_id = tokenizer.token_to_id("[MASK]")
pad_id = tokenizer.padding["pad_id"]
# adjust length to fit the max sequence length
truncate_seq_pair(sentence_a, sentence_b, max_seq_length-3)
num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
# create unmodified tokens sequence
tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
assert len(tokens) == len(seg_id) == max_seq_length
# create the prediction targets
cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
random.shuffle(cand_indices)
num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
mlm_positions = sorted(cand_indices[:num_predictions])
mlm_labels = []
for i in mlm_positions:
mlm_labels.append(tokens[i])
# prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
if random.random() < 0.8:
tokens[i] = mask_id
elif random.random() < 0.5:
tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)
# randomly mask some tokens
ret = {
"tokens": tokens,
"segment_ids": seg_id,
"is_random_next": is_random_next,
"masked_positions": mlm_positions,
"masked_labels": mlm_labels,
}
return ret
该函数会创建一个 token 序列:[CLS] <text_1> [SEP] <text_2> [SEP],其中 <text_1> 和 <text_2> 是包含掩码 token 的句子对。特殊 token 的 ID 由 tokenizer 提供。
首先,对句子对进行截断,使其适配最大序列长度,并为三个特殊 token 预留空间。随后将序列填充到期望长度。接着创建 segment 标签,用于区分句子:第一个句子为 0,第二个句子为 1,填充部分为 -1。
所有非特殊 token 都是可被掩码的候选项。它们的索引会被打乱,然后选取前 num_predictions 个位置。该数量由 mask_prob(默认 15%)决定,并且上限为 max_predictions_per_seq(默认 20):
num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
变量 mlm_positions 是被掩码位置的索引列表,且按升序排列。变量 mlm_labels 是这些被掩码位置上原始 token 的列表。当需要从词表中选择一个随机 token 时,可以通过 tokenizer 来获取:
tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)
词表中的前四个 token 是特殊 token,它们不会被选为掩码对象。最终返回的字典 ret 就是用于训练 BERT 模型的“样本”。
保存训练数据以便复用
到目前为止,你已经学习了如何将原始数据集处理为包含掩码 token 的句子对,以用于 MLM 和 NSP 训练任务。通过上述代码,你可以创建一个由字典组成的列表作为训练数据。然而,这种方式未必是最适合训练循环的数据供给方式。
在 PyTorch 代码中,提供训练数据的标准方式是使用 Dataset 类,也就是定义一个类似如下的类:
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
关键在于实现 len 和 getitem 方法,分别用于返回数据集中样本的总数量以及指定索引位置的样本。然而,这种方式对 BERT 训练来说可能并不是最优的,因为在初始化数据集类时,往往需要一次性将整个数据集加载到内存中。当数据集规模较大时,这种做法并不高效。
一种替代方案是使用 Hugging Face datasets 库中的 Dataset 类。它隐藏了许多数据管理的细节,使你能够将精力集中在更重要的事情上。假设你已经创建了一个用于生成样本的生成器函数:
def create_dataset(docs, tokenizer):
...
for doc in docs:
for sample in create_samples(doc):
yield sample
你可以通过如下方式创建一个数据集对象:
...
from datasets import Dataset
dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
dataset.to_parquet("wikitext-103_train_data.parquet")
这两行代码会从生成器函数中拉取所有样本,并将它们保存为 parquet 格式的文件。根据数据集的大小,这个过程可能需要一些时间。gen_kwargs 是一个用于向生成器函数传递关键字参数的字典,其内容应当与你定义该函数时的参数保持一致。
一旦你将数据集保存为 parquet 格式,就可以将其重新加载,并尝试打印一些样本:
...
dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
for i, sample in enumerate(dataset):
if i > 5:
break
print(sample)
这正是 parquet 格式发挥优势的地方。Hugging Face datasets 库同样支持 JSON 和 CSV 格式,但 parquet 是一种压缩的列式存储格式,在数据存储和检索方面更加高效。设置 streaming=True 是可选的,它允许你只加载当前正在使用的数据部分,而不是一次性将整个数据集加载到内存中。
将所有内容整合在一起,完整代码如下:
"""Process the WikiText dataset for training the BERT model. Using Hugging Face
datasets library.
"""
import time
import random
from typing import Iterator
import tokenizers
from datasets import load_dataset, Dataset
# path and name of each dataset
DATASETS = {
"wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
"wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"
def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
"""Load wikitext dataset and extract text as documents"""
dataset = load_dataset(path, name, split="train")
docs: list[list[list[int]]] = []
for line in dataset["text"]:
line = line.strip()
if not line or line.startswith("="):
docs.append([]) # new document encountered
else:
tokens = tokenizer.encode(line).ids
docs[-1].append(tokens)
docs = [doc for doc in docs if doc] # remove empty documents
return docs
def create_dataset(
docs: list[list[list[int]]],
tokenizer: tokenizers.Tokenizer,
max_seq_length: int = 512,
doc_repeat: int = 10,
mask_prob: float = 0.15,
short_seq_prob: float = 0.1,
max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
"""Generate samples from all documents"""
doc_indices = list(range(len(docs))) * doc_repeat
for doc_idx in doc_indices:
yield from generate_samples(doc_idx, docs, tokenizer, max_seq_length, mask_prob, short_seq_prob, max_predictions_per_seq)
def generate_samples(
doc_idx: int,
all_docs: list[list[list[int]]],
tokenizer: tokenizers.Tokenizer,
max_seq_length: int = 512,
mask_prob: float = 0.15,
short_seq_prob: float = 0.1,
max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
"""Generate samples from a given document"""
# number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
target_length = max_seq_length - 3
if random.random() < short_seq_prob:
# shorter sequence is used 10% of the time
target_length = random.randint(2, target_length)
# copy the document
chunks = []
for chunk in all_docs[doc_idx]:
chunks.append(chunk)
# exhaust chunks and create samples
while chunks:
# scan until target token length
running_length = 0
end = 1
while end < len(chunks) and running_length < target_length:
running_length += len(chunks[end-1])
end += 1
# randomly separate the chunk into two segments
sep = random.randint(1, end-1) if end > 1 else 1
sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
# sentence B: may be from another document
if not sentence_b or random.random() < 0.5:
# find another document (must not be the same as doc_idx)
b_idx = random.randint(0, len(all_docs)-2)
if b_idx >= doc_idx:
b_idx += 1
# sentence B starts from a random position in the new document
sentence_b = []
running_length = len(sentence_a)
i = random.randint(0, len(all_docs[b_idx])-1)
while i < len(all_docs[b_idx]) and running_length < target_length:
sentence_b.extend(all_docs[b_idx][i])
running_length += len(all_docs[b_idx][i])
i += 1
is_random_next = True
chunks = chunks[sep:]
else:
is_random_next = False
chunks = chunks[end:]
# create a sample from the pair
yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq)
def create_sample(
sentence_a: list[list[int]],
sentence_b: list[list[int]],
is_random_next: bool,
tokenizer: tokenizers.Tokenizer,
max_seq_length: int = 512,
mask_prob: float = 0.15,
max_predictions_per_seq: int = 20,
) -> dict:
"""Create a sample from a pair of sentences"""
# Collect id of special tokens
cls_id = tokenizer.token_to_id("[CLS]")
sep_id = tokenizer.token_to_id("[SEP]")
mask_id = tokenizer.token_to_id("[MASK]")
pad_id = tokenizer.padding["pad_id"]
# adjust length to fit the max sequence length
truncate_seq_pair(sentence_a, sentence_b, max_seq_length-3)
num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
# create unmodified tokens sequence
tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
assert len(tokens) == len(seg_id) == max_seq_length
# create the prediction targets
cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
random.shuffle(cand_indices)
num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
mlm_positions = sorted(cand_indices[:num_predictions])
mlm_labels = []
for i in mlm_positions:
mlm_labels.append(tokens[i])
# prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
if random.random() < 0.8:
tokens[i] = mask_id
elif random.random() < 0.5:
tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)
# randomly mask some tokens
ret = {
"tokens": tokens,
"segment_ids": seg_id,
"is_random_next": is_random_next,
"masked_positions": mlm_positions,
"masked_labels": mlm_labels,
}
return ret
def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
"""Truncate a pair of sequences until below a maximum sequence length."""
while len(sentence_a) + len(sentence_b) > max_num_tokens:
# pick the longer sentence to remove tokens from
candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
# remove one token from either end in equal probabilities
if random.random() < 0.5:
candidate.pop(0)
else:
candidate.pop()
if __name__ == "__main__":
print(time.time(), "started")
tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
print(time.time(), "loaded tokenizer")
docs = create_docs(PATH, NAME, tokenizer)
print(time.time(), "created docs with %d documents" % len(docs))
dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
print(time.time(), "created dataset from generator")
# Save dataset to parquet file
dataset.to_parquet("wikitext-103_train_data.parquet")
print(time.time(), "saved dataset to parquet file")
# Load dataset from parquet file
dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
print(time.time(), "loaded dataset from parquet file")
# Print a few samples
for i, sample in enumerate(dataset):
print(i)
print(sample)
print()
if i >= 3:
break
print(time.time(), "finished")
运行这段代码后,你将看到如下形式的输出:
1763913652.5099447 started
1763913652.5830114 loaded tokenizer
1763913817.1271229 created docs with 268854 documents
Generating train split: 4268307 examples [11:42:36, 101.25 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 73/73 [00:10<00:00, 7.05it/s]
1763956111.2021146 created dataset from generator
Creating parquet from Arrow format: 100%|█████████████| 4269/4269 [06:13<00:00, 11.42ba/s]
1763956487.0040812 saved dataset to parquet file
Generating train split: 4268307 examples [06:22, 11168.96 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 74/74 [00:09<00:00, 8.20it/s]
1763956881.6215432 loaded dataset from parquet file
0
{'tokens': [1, 10887, 4875, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': True, 'masked_positions': [29, 58, ...],
'masked_labels': [15, 8551, ...]}
1
{'tokens': [1, 8792, 9150, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': True, 'masked_positions': [15, 19, ...],
'masked_labels': [8522, 9100, ...]}
2
{'tokens': [1, 8506, 8556, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': False, 'masked_positions': [3, 8, ...],
'masked_labels': [19367, 29188, ...]}
3
{'tokens': [1, 8544, 8910, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': False, 'masked_positions': [13, 16, ...],
'masked_labels': [8656, 12114, ...]}
1763956881.6688802 finished
中间不定期打印的时间戳是有意为之,用于展示各个阶段所消耗的时间。这段代码会处理 Wikitext-103 数据集,整体运行需要数小时。完成之后,生成的 parquet 文件可以支持对样本进行快速、高效的迭代。在测试阶段,你也可以改用规模更小的 Wikitext-2 数据集,这样可以在几分钟内看到代码的运行效果。
总结
在本文中,你学习了如何为 BERT 训练准备数据。你了解了如何创建掩码语言模型(MLM)训练数据以及下一句预测(NSP)训练数据。同时,你还学习了如何将数据保存为 parquet 格式以便复用。
522

被折叠的 条评论
为什么被折叠?



