BERT_Dataset
import os
import random
import torch
from d2l import torch as d2l
#@save
d2l.DATA_HUB['wikitext-2'] = (
'https://s3.amazonaws.com/research.metamind.io/wikitext/'
'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
#@save
def _read_wiki(data_dir):
file_name = os.path.join(data_dir, 'wiki.train.tokens')
with open(file_name, encoding='utf-8') as f:
lines = f.readlines()
# 大写字母转换为小写字母
paragraphs = [line.strip().lower().split(' . ')
for line in lines if len(line.split(' . ')) >= 2]
random.shuffle(paragraphs)
return paragraphs
在WikiText-2数据集中,每行代表一个段落,其中在任意标点符号及其前面的词元之间插入空格。保留至少有两句话的段落。为了简单起见,我们仅使用句号作为分隔符来拆分句子。
生成下一句预测任务的数据样本(任务二)
#@save
def _get_next_sentence(sentence, next_sentence, paragraphs):
if random.random() < 0.5:
is_next = True
else:
# paragraphs是三重列表的嵌套
next_sentence = random.choice(random.choice(paragraphs))
is_next = False
return sentence, next_sentence, is_next
50%概率真的是下一个句子
50%随机选一个句子
is_next是标签
#@save
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
nsp_data_from_paragraph = []
for i in range(len(paragraph) - 1):
tokens_a, tokens_b, is_next = _get_next_sentence(
paragraph[i], paragraph[i + 1], paragraphs)
# 考虑1个'<cls>'词元和2个'<sep>'词元
if len(tokens_a) + len(tokens_b) + 3 > max_len:
continue

这段代码展示了如何从WikiText-2数据集中构建BERT的预训练任务数据,包括下一句预测(NSP)和遮蔽语言模型(MLM)任务。首先,读取并处理数据,然后生成这两个任务的样本。接着,对数据进行填充以适应固定长度,最后创建数据集实例。代码还包含了加载数据集的函数,用于生成批量数据迭代器。
最低0.47元/天 解锁文章
2718





