BERT 预训练数据构建全流程:从原始语料到 Parquet 数据集

【精选优质专栏推荐】


每个专栏均配有案例与图文讲解,循序渐进,适合新手与进阶学习者,欢迎订阅。

在这里插入图片描述

前言

BERT 是一种仅编码器(encoder-only)的 Transformer 模型,在针对各种 NLP 任务进行微调之前,会先在掩码语言模型(MLM)和下一句预测(NSP)任务上进行预训练。预训练需要特殊的数据准备方式。在本文中,你将学习如何:

  • 创建掩码语言模型(MLM)训练数据

  • 创建下一句预测(NSP)训练数据

  • 为 BERT 训练设置标签

  • 使用 Hugging Face datasets 来存储训练数据

本文分为四个部分,分别是:

  • 准备文档

  • 从文档中创建句子对

  • 掩码 Token

  • 保存训练数据以便复用

准备文档

与仅解码器模型不同,BERT 的预训练更加复杂。如上一篇文章所述,预训练会同时优化 MLM 和 NSP 任务的组合损失。因此,训练数据必须同时为这两个任务打上标签。

我们按照 Google 的 BERT 实现方式,使用 Wikitext-2 或 Wikitext-103 数据集。数据集中的每一行要么是空行、要么是以 “=” 开头的标题行、要么是普通文本。只有普通文本行会被用于训练。

BERT 的训练样本需要包含两个“句子”。为简单起见,定义如下:

“句子”是数据集中的一行文本
文档是由连续的“句子”组成的序列,不同文档之间由空行或标题行分隔

假设你已经像上一篇文章中那样训练好了一个 tokenizer,下面我们创建一个函数,将文本收集为文档列表:

import tokenizers
from datasets import load_dataset
 
def create_docs(path, name, tokenizer):
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])   # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs
 
# load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("wikitext-103_wordpiece.json")
docs = create_docs("wikitext", "wikitext-103-raw-v1", tokenizer)

这段代码会按顺序处理文本行。当遇到文档分隔符时,就为后续的行创建一个新的列表。每一行都会以 tokenizer 输出的整数列表形式进行存储。

跟踪文档对于 NSP 任务至关重要:只有当两个句子都来自同一个文档时,它们才能构成一个“下一句”序列对。

从文档中创建句子对

下一步是从文档中抽取句子对。句子对可以是同一文档中的相邻句子,也可以是来自不同文档的随机句子。我们使用如下算法来创建句子对:

遍历每个文档中的每一个句子,将其作为第一个句子
对于第二个句子,要么选择同一文档中的下一句,要么从另一个文档中随机选择一句

但这里有一个约束条件:句子对的总长度不能超过 BERT 的最大序列长度。如有必要,需要对句子进行截断。

下面展示了如何用 Python 实现该算法:

...
import random
 
# copy the document
chunks = []
for chunk in all_docs[doc_idx]:
    chunks.append(chunk)
 
# exhaust chunks and create samples
while chunks:
    # scan until target token length
    running_length = 0
    end = 1
    while end < len(chunks) and running_length < target_length:
        running_length += len(chunks[end-1])
        end += 1
    # randomly separate the chunk into two segments
    sep = random.randint(1, end-1) if end > 1 else 1
    sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
    sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
    # sentence B: may be from another document
    if not sentence_b or random.random() < 0.5:
        # find another document (must not be the same as doc_idx)
        b_idx = random.randint(0, len(all_docs)-2)
        if b_idx >= doc_idx:
            b_idx += 1
        # sentence B starts from a random position in the new document
        sentence_b = []
        running_length = len(sentence_a)
        i = random.randint(0, len(all_docs[b_idx])-1)
        while i < len(all_docs[b_idx]) and running_length < target_length:
            sentence_b.extend(all_docs[b_idx][i])
            running_length += len(all_docs[b_idx][i])
            i += 1
        is_random_next = True
        chunks = chunks[sep:]
    else:
        is_random_next = False
        chunks = chunks[end:]
    # the pair is found
    pair = (sentence_a, sentence_b)

这段代码会针对索引为 doc_idx 的某个文档创建句子对。最开始的 for 循环会将句子复制为 chunks,以避免修改原始文档。while 循环会扫描 chunks,直到达到目标 token 长度,然后将其随机拆分为两个片段。

在 50% 的概率下,第二个句子会被替换为来自另一个文档的随机句子。这个较大的 if 代码块负责创建 NSP 任务的标签(记录在 is_random_next 中),并从另一个文档中采样一个句子。

在每一次迭代结束时,chunks 会被更新,仅保留尚未使用的部分。当该列表为空时,说明文档已被完全处理。sentence_a 和 sentence_b 都是由整数形式的 token ID 组成的列表。

这种方法遵循了 Google 原始的 BERT 实现方式,尽管它并不会穷举所有可能的组合。上述创建的句子对可能会超过目标序列长度,因此需要进行截断。截断的实现如下所示:

def truncate_seq_pair(sentence_a, sentence_b, max_num_tokens):
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()

截断操作会反复应用在较长的句子上,直到总长度小于目标值为止。每次以相同的概率从句子的首部或尾部移除一个 token。最终结果可能是原始句子中间的一段片段,这也正是代码中使用 “chunk” 这一命名方式的原因。

掩码 Token

掩码 token 是 BERT 训练数据中最关键的部分。原始论文中规定有 15% 的 token 会被掩码。实际上,这意味着模型只在其输出中的 15% 位置上被训练去预测 token。在这 15% 中,token 可能会是以下三种情况之一:

  • 80% 的情况下,token 会被替换为 [MASK] token。

  • 10% 的情况下,token 会被替换为词表中的一个随机 token。

  • 10% 的情况下,token 保持不变。

在所有情况下,模型都必须正确预测原始 token。在创建好句子对之后,我们可以按如下方式实现掩码过程:

def create_sample(sentence_a, sentence_b, is_random_next, tokenizer,
                  max_seq_length=512, mask_prob=0.15, max_predictions_per_seq=20):
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]
    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length-3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)
    # randomly mask some tokens
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret

该函数会创建一个 token 序列:[CLS] <text_1> [SEP] <text_2> [SEP],其中 <text_1> 和 <text_2> 是包含掩码 token 的句子对。特殊 token 的 ID 由 tokenizer 提供。

首先,对句子对进行截断,使其适配最大序列长度,并为三个特殊 token 预留空间。随后将序列填充到期望长度。接着创建 segment 标签,用于区分句子:第一个句子为 0,第二个句子为 1,填充部分为 -1。

所有非特殊 token 都是可被掩码的候选项。它们的索引会被打乱,然后选取前 num_predictions 个位置。该数量由 mask_prob(默认 15%)决定,并且上限为 max_predictions_per_seq(默认 20):

num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
num_predictions = min(max_predictions_per_seq, max(1, num_predictions))

变量 mlm_positions 是被掩码位置的索引列表,且按升序排列。变量 mlm_labels 是这些被掩码位置上原始 token 的列表。当需要从词表中选择一个随机 token 时,可以通过 tokenizer 来获取:

tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)

词表中的前四个 token 是特殊 token,它们不会被选为掩码对象。最终返回的字典 ret 就是用于训练 BERT 模型的“样本”。

保存训练数据以便复用

到目前为止,你已经学习了如何将原始数据集处理为包含掩码 token 的句子对,以用于 MLM 和 NSP 训练任务。通过上述代码,你可以创建一个由字典组成的列表作为训练数据。然而,这种方式未必是最适合训练循环的数据供给方式。

在 PyTorch 代码中,提供训练数据的标准方式是使用 Dataset 类,也就是定义一个类似如下的类:

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
 
    def __len__(self):
        return len(self.data)
 
    def __getitem__(self, idx):
        return self.data[idx]

关键在于实现 lengetitem 方法,分别用于返回数据集中样本的总数量以及指定索引位置的样本。然而,这种方式对 BERT 训练来说可能并不是最优的,因为在初始化数据集类时,往往需要一次性将整个数据集加载到内存中。当数据集规模较大时,这种做法并不高效。

一种替代方案是使用 Hugging Face datasets 库中的 Dataset 类。它隐藏了许多数据管理的细节,使你能够将精力集中在更重要的事情上。假设你已经创建了一个用于生成样本的生成器函数:

def create_dataset(docs, tokenizer):
    ...
    for doc in docs:
        for sample in create_samples(doc):
            yield sample

你可以通过如下方式创建一个数据集对象:

...
from datasets import Dataset
 
dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
dataset.to_parquet("wikitext-103_train_data.parquet")

这两行代码会从生成器函数中拉取所有样本,并将它们保存为 parquet 格式的文件。根据数据集的大小,这个过程可能需要一些时间。gen_kwargs 是一个用于向生成器函数传递关键字参数的字典,其内容应当与你定义该函数时的参数保持一致。

一旦你将数据集保存为 parquet 格式,就可以将其重新加载,并尝试打印一些样本:

...
dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
for i, sample in enumerate(dataset):
    if i > 5:
        break
    print(sample)

这正是 parquet 格式发挥优势的地方。Hugging Face datasets 库同样支持 JSON 和 CSV 格式,但 parquet 是一种压缩的列式存储格式,在数据存储和检索方面更加高效。设置 streaming=True 是可选的,它允许你只加载当前正在使用的数据部分,而不是一次性将整个数据集加载到内存中。

将所有内容整合在一起,完整代码如下:

"""Process the WikiText dataset for training the BERT model. Using Hugging Face
datasets library.
"""
 
import time
import random
from typing import Iterator
 
import tokenizers
from datasets import load_dataset, Dataset
 
# path and name of each dataset
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"
 
 
def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs: list[list[list[int]]] = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])   # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs
 
 
def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from all documents"""
    doc_indices = list(range(len(docs))) * doc_repeat
    for doc_idx in doc_indices:
        yield from generate_samples(doc_idx, docs, tokenizer, max_seq_length, mask_prob, short_seq_prob, max_predictions_per_seq)
 
def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from a given document"""
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # shorter sequence is used 10% of the time
        target_length = random.randint(2, target_length)
 
    # copy the document
    chunks = []
    for chunk in all_docs[doc_idx]:
        chunks.append(chunk)
 
    # exhaust chunks and create samples
    while chunks:
        # scan until target token length
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end-1])
            end += 1
        # randomly separate the chunk into two segments
        sep = random.randint(1, end-1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
        # sentence B: may be from another document
        if not sentence_b or random.random() < 0.5:
            # find another document (must not be the same as doc_idx)
            b_idx = random.randint(0, len(all_docs)-2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the new document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx])-1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]
        # create a sample from the pair
        yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq)
 
def create_sample(
    sentence_a: list[list[int]],
    sentence_b: list[list[int]],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create a sample from a pair of sentences"""
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]
    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length-3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size()-1)
    # randomly mask some tokens
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret
 
 
def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences until below a maximum sequence length."""
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()
 
 
if __name__ == "__main__":
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")
    # Save dataset to parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")
    # Load dataset from parquet file
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")
    # Print a few samples
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")

运行这段代码后,你将看到如下形式的输出:

1763913652.5099447 started
1763913652.5830114 loaded tokenizer
1763913817.1271229 created docs with 268854 documents
Generating train split: 4268307 examples [11:42:36, 101.25 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 73/73 [00:10<00:00, 7.05it/s]
1763956111.2021146 created dataset from generator
Creating parquet from Arrow format: 100%|█████████████| 4269/4269 [06:13<00:00, 11.42ba/s]
1763956487.0040812 saved dataset to parquet file
Generating train split: 4268307 examples [06:22, 11168.96 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 74/74 [00:09<00:00, 8.20it/s]
1763956881.6215432 loaded dataset from parquet file
0
{'tokens': [1, 10887, 4875, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': True, 'masked_positions': [29, 58, ...],
'masked_labels': [15, 8551, ...]}
 
1
{'tokens': [1, 8792, 9150, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': True, 'masked_positions': [15, 19, ...],
'masked_labels': [8522, 9100, ...]}
 
2
{'tokens': [1, 8506, 8556, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': False, 'masked_positions': [3, 8, ...],
'masked_labels': [19367, 29188, ...]}
 
3
{'tokens': [1, 8544, 8910, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1],
'is_random_next': False, 'masked_positions': [13, 16, ...],
'masked_labels': [8656, 12114, ...]}
 
1763956881.6688802 finished

中间不定期打印的时间戳是有意为之,用于展示各个阶段所消耗的时间。这段代码会处理 Wikitext-103 数据集,整体运行需要数小时。完成之后,生成的 parquet 文件可以支持对样本进行快速、高效的迭代。在测试阶段,你也可以改用规模更小的 Wikitext-2 数据集,这样可以在几分钟内看到代码的运行效果。

总结

在本文中,你学习了如何为 BERT 训练准备数据。你了解了如何创建掩码语言模型(MLM)训练数据以及下一句预测(NSP)训练数据。同时,你还学习了如何将数据保存为 parquet 格式以便复用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

秋说

感谢打赏

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值