TR6 - Transformer实战 单词预测



理论知识

关于数据集 Wikitext-2

WikiText (The WikiText Long Term Dependency Language Modeling Dataset, 英语词库数据集)是一个包含1亿个词汇的英文数据集,这些词汇从Wikipedia的优质文章和标杆文章中提取得到。包含WikiText-2和WikiText-103两个版本,相比于著名的Penn Treebank(PTB)数据集来说的,前者是PTB词汇数量的2倍,后者是110倍。每个词汇还同时保留了产生该词汇的原始文章,尤其适合需要长时依赖(Long Term Dependency)的自然语言建模场景。

  • 数据来源:从维基百科抽取
  • 数据内容:包含维基百科的文章内容,包括各种主题和领域的信息。经过预处理和清洗,以提供干净和可用于训练的文本数据
  • 数据规模:Wikitext-2 包含了超过2088628个词标记文本,以及其中1915997个词标记(token)用于训练,172430个词标记用于验证和186716个词标记用于测试。
  • 数据格式:纯文本形式存储,每个文本文件包含一个维基百科文章的内容。文本以段落句子为单位进行分割。
  • 用途:通常用于语言建模任务,其中模型的目标是根据之前的上下文来预测下一个词或下一个句子。此外,还可以用于其他的文本生成任务,如机器翻译、摘要生成等。

模型结构

模型结构图

代码实现

0. 环境

pytorch: 2.1.0
torchtext: 0.16.0

1. 加载数据集

使用torchtext生成Wikitext-2数据集
batchify() 可以将数据排列成 batch_size 列。如果数据没有均匀地分成batch_size列,则会对数据进行修剪。
例如:将字母表作为数据(总长度是26),然后设置batch_size=4batchify会将字母表分成4个长度为6的序列,如图所示
batchify函数示意
由于torchtext已经停止更新了,源码里面的URL地址已经无法下载数据集,我们先从百度下载一个,地址为

https://aistudio.baidu.com/datasetdetail/230431

在当前目录下创建路径 datasets/WikiText2/ 然后将下载的wikitext-2-v1.zip放入这个文件夹

from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import dataset
from torch import nn, Tensor
import math, os, torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from tempfile import TemporaryDirectory

# 全局设备对象
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载训练集,创建词汇表
train_iter = WikiText2(split='train', root='.')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """将原始文本转换成扁平的张量"""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

def batchify(data: Tensor, bsz: int) -> Tensor:
    """将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素
    参数:
        data: Tensor, 形状为``[N]``
        bsz: int, 批大小
    返回:
        形状为 [N // bsz, bsz] 的张量
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len*bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

# 创建数据集
train_iter, val_iter, test_iter = WikiText2(root='.')
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)


batch_size = 20
eval_batch_size = 10

# 将三类数据集都处理成固定长度
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, batch_size)
test_data = batchify(test_data, batch_size)

# 编写数据集取值函数(就像CV里的data_loader一样)
bptt = 35

def get_batch(source: Tensor, i: int) -> tuple[Tensor, Tensor]:
    """获取批次数据
    参数:
        source: Tensor, 形状为 ``[full_seq_len, batch_size]``
        i: int, 当前批次索引
    返回:
        tuple(data, target),
        - data形状为[seq_len, batch_size]
        - target形状为[seq_len * batch_size]
    """
    # 计算当前批次的序列长度,最大为bptt,确保不超过source的长度
    seq_len = min(bptt, len(source) - 1 - i)
    # 获取data,从i开始,长度为seq_len
    data = source
### 关于CNN-BiGRU-Transformer预测模型的流程图与架构 #### CNN-BiGRU-Transformer预测模型概述 CNN-BiGRU-Transformer是一种混合深度学习架构,结合了卷积神经网络(CNN)、双向门控循环单元(BiGRU)以及Transformer的优点。该架构能够有效处理多维度时间序列数据,在特征提取、时序建模和全局依赖关系捕捉方面表现出色[^1]。 #### 架构组成详解 1. **CNN模块** 卷积层负责从输入数据中提取局部空间特征。通过滑动窗口机制,CNN可以高效捕获图像或信号中的边缘、纹理等低级特征[^3]。 2. **BiGRU模块** 双向门控循环单元(BiGRU)用于对时间序列数据进行前向和反向建模。相比传统的RNN,GRU减少了计算开销并缓解了梯度消失问题。BiGRU则进一步增强了对上下文信息的理解能力。 3. **Transformer模块** Transformer主要承担全局依赖关系的学习任务。其自注意力机制允许模型关注不同位置的重要特征,从而实现更深层次的信息交互[^2]。 #### 流程图描述 以下是基于上述组件构建的CNN-BiGRU-Transformer预测模型的整体流程: ```plaintext 输入数据 -> CNN特征提取 -> BiGRU时序建模 -> Transformer全局依赖学习 -> 输出预测结果 ``` 具体步骤如下: - 输入原始数据经过预处理后送入CNN部分; - 提取的空间特征被传递至BiGRU以建立时间上的关联性; - 最终由Transformer完成高层次语义表示,并生成最终输出[^4]。 #### 架构图示例 下面提供了一个简化版的Markdown代码块来展示此架构图布局方式: ```mermaid graph TD; A[输入数据] --> B[CNN特征提取]; B --> C[BiGRU时序建模]; C --> D[Transformer全局依赖学习]; D --> E[输出预测结果]; ``` > 注:为了实际绘制图形,建议使用Mermaid.js或其他支持矢量图形编辑工具如PowerPoint、Lucidchart等按照以上逻辑创建可视化图表。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值