第28周:Transformer 实现文本分类 - Embedding版

目录

 前言

一、前期准备

1.1 环境安装

1.2 加载数据

二、数据预处理

2.1 构建词典

2.2 进行one-hot编码

2.3 自定义数据集类

2.4 定义填充函数

2.5 构建数据集

三、模型构建

3.1 定义位置编码器

3.2 定义Transformer模型

3.3 定义训练函数

3.4 定义测试函数

四、训练模型

4.1 模型训练

4.2 模型评估

五、模型调优

总结


 前言

说在前面

1)本周任务:在上周的代码基础上,将嵌入方式改为Embedding嵌入

2)运行环境:Python3.8、Pycharm2020、torch1.12.1+cu113


一、前期准备

1.1 环境安装

本文是基于Pytorch框架实现的文本分类

代码如下:

#一、准备工作
#1.1 环境安装
import torch,torchvision
print(torch.__version__)
print(torchvision.__version__)
import torch.nn as nn
from torchvision import transforms, datasets
import os, PIL,pathlib,warnings
import pandas as pd

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

打印输出:

2.0.0+cu118
0.15.1+cu118
cuda

1.2 加载数据

代码如下:


#1.2 加载数据
#加载自定义中文数据
train_data = pd.read_csv('train.csv', sep='\t', header=None)
print(train_data.head())
label_name = list(set(train_data[1].values[:]))
print('label name:', label_name)

打印输出:

  0              1
0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query
1                从这里怎么回家   Travel-Query
2       随便播放一首专辑阁楼里的佛里的歌     Music-Play
3              给看一下墓王之王嘛  FilmTele-Play
4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play
['Radio-Listen', 'TVProgram-Play', 'Video-Play', 'Travel-Query', 'Weather-Query', 'Music-Play', 'HomeAppliance-Control', 'Other', 'Calendar-Query', 'FilmTele-Play', 'Alarm-Update', 'Audio-Play']

二、数据预处理

2.1 构建词典

需要安装jieba分词库,安装语句pip install jieba

代码如下(示例):

#二、数据预处理
#2.1 构建词典
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
import jieba

#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])


text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

打印输出:

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\XiaoMa\AppData\Local\Temp\jieba.cache
Loading model cost 0.320 seconds.
Prefix dict has been built successfully.
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
11

2.2 进行one-hot编码

代码如下:

#from functools import partial

X = [text_pipeline(i) for i in train_data[0].values[:]]
y = [label_pipeline(i) for i in train_data[1].values[:]]

# 对便签 y 进行 one-hot 编码
numbers_array = np.array(y)             # 转换为 NumPy 数组
num_classes   = np.max(numbers_array)+1 # 计算类别数量
y = np.eye(num_classes)[numbers_array]  # 进行 one-hot 编码

2.3 自定义数据集类

代码如下:

#from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts  = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

2.4 定义填充函数

代码如下:

import torch.nn.functional as F

max_len = max(len(i) for i in X)

def collate_batch(batch, max_len):
    texts, labels = zip(*batch)
    padded_texts = [F.pad(text, (0, max_len - len(text)), value=0) for text in texts]
    padded_texts = torch.stack(padded_texts)
    labels = torch.tensor(labels, dtype=torch.float)#.unsqueeze(1)
    return padded_texts.to(device), labels.to(device)

# 使用 partial 函数创建 collate_fn,传入参数
collate_fn = partial(collate_batch, max_len=max_len)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值