TR6 - Transformer实战 单词预测



理论知识

关于数据集 Wikitext-2

WikiText (The WikiText Long Term Dependency Language Modeling Dataset, 英语词库数据集)是一个包含1亿个词汇的英文数据集,这些词汇从Wikipedia的优质文章和标杆文章中提取得到。包含WikiText-2和WikiText-103两个版本,相比于著名的Penn Treebank(PTB)数据集来说的,前者是PTB词汇数量的2倍,后者是110倍。每个词汇还同时保留了产生该词汇的原始文章,尤其适合需要长时依赖(Long Term Dependency)的自然语言建模场景。

  • 数据来源:从维基百科抽取
  • 数据内容:包含维基百科的文章内容,包括各种主题和领域的信息。经过预处理和清洗,以提供干净和可用于训练的文本数据
  • 数据规模:Wikitext-2 包含了超过2088628个词标记文本,以及其中1915997个词标记(token)用于训练,172430个词标记用于验证和186716个词标记用于测试。
  • 数据格式:纯文本形式存储,每个文本文件包含一个维基百科文章的内容。文本以段落句子为单位进行分割。
  • 用途:通常用于语言建模任务,其中模型的目标是根据之前的上下文来预测下一个词或下一个句子。此外,还可以用于其他的文本生成任务,如机器翻译、摘要生成等。

模型结构

模型结构图

代码实现

0. 环境

pytorch: 2.1.0
torchtext: 0.16.0

1. 加载数据集

使用torchtext生成Wikitext-2数据集
batchify() 可以将数据排列成 batch_size 列。如果数据没有均匀地分成batch_size列,则会对数据进行修剪。
例如:将字母表作为数据(总长度是26),然后设置batch_size=4batchify会将字母表分成4个长度为6的序列,如图所示
batchify函数示意
由于torchtext已经停止更新了,源码里面的URL地址已经无法下载数据集,我们先从百度下载一个,地址为

https://aistudio.baidu.com/datasetdetail/230431

在当前目录下创建路径 datasets/WikiText2/ 然后将下载的wikitext-2-v1.zip放入这个文件夹

from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import dataset
from torch import nn, Tensor
import math, os, torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from tempfile import TemporaryDirectory

# 全局设备对象
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载训练集,创建词汇表
train_iter = WikiText2(split='train', root='.')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """将原始文本转换成扁平的张量"""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

def batchify(data: Tensor, bsz: int) -> Tensor:
    """将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素
    参数:
        data: Tensor, 形状为``[N]``
        bsz: int, 批大小
    返回:
        形状为 [N // bsz, bsz] 的张量
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len*bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

# 创建数据集
train_iter, val_iter, test_iter = WikiText2(root='.')
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)


batch_size = 20
eval_batch_size = 10

# 将三类数据集都处理成固定长度
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, batch_size)
test_data = batchify(test_data, batch_size)

# 编写数据集取值函数(就像CV里的data_loader一样)
bptt = 35

def get_batch(source: Tensor, i: int) -> tuple[Tensor, Tensor]:
    """获取批次数据
    参数:
        source: Tensor, 形状为 ``[full_seq_len, batch_size]``
        i: int, 当前批次索引
    返回:
        tuple(data, target),
        - data形状为[seq_len, batch_size]
        - target形状为[seq_len * batch_size]
    """
    # 计算当前批次的序列长度,最大为bptt,确保不超过source的长度
    seq_len = min(bptt, len(source) - 1 - i)
    # 获取data,从i开始,长度为seq_len
    data = source[i:i+seq_len]
    # 获取target,从i+1开始,长度为seq_len,并将其形状转换为一维张量
    target = 
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值