- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
理论知识
关于数据集 Wikitext-2
WikiText (The WikiText Long Term Dependency Language Modeling Dataset, 英语词库数据集)是一个包含1亿个词汇的英文数据集,这些词汇从Wikipedia的优质文章和标杆文章中提取得到。包含WikiText-2和WikiText-103两个版本,相比于著名的Penn Treebank(PTB)数据集来说的,前者是PTB词汇数量的2倍,后者是110倍。每个词汇还同时保留了产生该词汇的原始文章,尤其适合需要长时依赖(Long Term Dependency)的自然语言建模场景。
- 数据来源:从维基百科抽取
- 数据内容:包含维基百科的文章内容,包括各种主题和领域的信息。经过预处理和清洗,以提供干净和可用于训练的文本数据
- 数据规模:Wikitext-2 包含了超过2088628个词标记文本,以及其中1915997个词标记(token)用于训练,172430个词标记用于验证和186716个词标记用于测试。
- 数据格式:纯文本形式存储,每个文本文件包含一个维基百科文章的内容。文本以段落句子为单位进行分割。
- 用途:通常用于语言建模任务,其中模型的目标是根据之前的上下文来预测下一个词或下一个句子。此外,还可以用于其他的文本生成任务,如机器翻译、摘要生成等。
模型结构
代码实现
0. 环境
pytorch: 2.1.0
torchtext: 0.16.0
1. 加载数据集
使用torchtext生成Wikitext-2数据集
batchify()
可以将数据排列成 batch_size
列。如果数据没有均匀地分成batch_size
列,则会对数据进行修剪。
例如:将字母表作为数据(总长度是26),然后设置batch_size=4
,batchify
会将字母表分成4个长度为6的序列,如图所示
由于torchtext已经停止更新了,源码里面的URL地址已经无法下载数据集,我们先从百度下载一个,地址为
https://aistudio.baidu.com/datasetdetail/230431
在当前目录下创建路径 datasets/WikiText2/
然后将下载的wikitext-2-v1.zip
放入这个文件夹
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import dataset
from torch import nn, Tensor
import math, os, torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from tempfile import TemporaryDirectory
# 全局设备对象
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载训练集,创建词汇表
train_iter = WikiText2(split='train', root='.')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
"""将原始文本转换成扁平的张量"""
data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
def batchify(data: Tensor, bsz: int) -> Tensor:
"""将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素
参数:
data: Tensor, 形状为``[N]``
bsz: int, 批大小
返回:
形状为 [N // bsz, bsz] 的张量
"""
seq_len = data.size(0) // bsz
data = data[:seq_len*bsz]
data = data.view(bsz, seq_len).t().contiguous()
return data.to(device)
# 创建数据集
train_iter, val_iter, test_iter = WikiText2(root='.')
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)
batch_size = 20
eval_batch_size = 10
# 将三类数据集都处理成固定长度
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, batch_size)
test_data = batchify(test_data, batch_size)
# 编写数据集取值函数(就像CV里的data_loader一样)
bptt = 35
def get_batch(source: Tensor, i: int) -> tuple[Tensor, Tensor]:
"""获取批次数据
参数:
source: Tensor, 形状为 ``[full_seq_len, batch_size]``
i: int, 当前批次索引
返回:
tuple(data, target),
- data形状为[seq_len, batch_size]
- target形状为[seq_len * batch_size]
"""
# 计算当前批次的序列长度,最大为bptt,确保不超过source的长度
seq_len = min(bptt, len(source) - 1 - i)
# 获取data,从i开始,长度为seq_len
data = source