0、前言
在Seq2Seq模型的学习过程中,做过一个文本翻译任务案例,多轮训练后,效果还算能看
Transformer作为NLP领域的扛把子,对于此类任务的处理会更为强大,下面将以基于Transformer模型来重新处理此任务,看看效果如何
1、需求概述
现有一个《data.txt》文件,里面存放了很多组翻译对(即:英文句子 - 中文句子 的组合)
要求针对此《data.txt》文件,使用Seq2Seq模型构建一个翻译系统,并验证翻译效果
2、需求分析
这是一个典型的翻译任务,要求系统在用户输入英文句子之后,输出与之对应的中文句子,可以用Transformer模型来实现
具体来说,至少应该考虑以下几个要点:
-
1、分词器:自定义一个分词器,用于根据输入的语料构建字典
-
输入的句子(src --> source,英文句子)构建两个字典
-
src_token2idx字典的格式为:{英文词:id}
-
src_idx2token字典的格式为:{id:英文词}
-
-
输出的句子(tgt --> target,中文句子)构建两个字典
-
tgt_token2idx字典的格式为:{中文词:id}
-
tgt_idx2token字典的格式为:{id:中文词}
-
-
-
2、数据打包工具:自定义一个数据集和数据的批处理函数,用于将《data.txt》文件中的内容打包成模型处理所需的数据格式
-
3、Transformer模型:包含编码器和解码器方法
-
4、模型训练和推理:自定义模型的训练和推理方法
3、代码实现
3.1 导包
import os
import joblib
import copy
import math
import pandas as pd
import jieba
import opencc
import random
import time
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.functional import log_softmax
from torch.optim.lr_scheduler import LambdaLR
from sklearn.model_selection import train_test_split
3.2 构建分词器
class Tokenizer(object):
"""
自定义一个分词器,用于根据输入的语料构建字典:
1、输入的句子(src --> source,英文句子)构建两个字典
src_token2idx字典的格式为:{英文词:id}
src_idx2token字典的格式为:{id:英文词}
2、输出的句子(tgt --> target,中文句子)构建两个字典
tgt_token2idx字典的格式为:{中文词:id}
tgt_idx2token字典的格式为:{id:中文词}
"""
def __init__(self, data_file, saved_dict):
"""
初始化
"""
# 定义语料文件的路径
self.data_file = data_file
# 定义字典存储的路径
self.saved_dict = saved_dict
# 输入侧 src --> source
self.src_token2idx = None
self.src_idx2token = None
self.src_dict_len = None
self.src_embed_dim = 512
self.src_hidden_size = 512
# 输出侧 tgt --> target
self.tgt_token2idx = None
self.tgt_idx2token = None
self.tgt_dict_len = None
self.tgt_embed_dim = 512
self.tgt_hidden_size = 512
self.tgt_max_len = 100
# 构建字典
self._build_dict()
def _build_dict(self):
"""
构建字典
"""
# 1、如果四个字典都有值,则不需要浪费资源重复构建,跳出这个构建字典的_build_dict方法即可
if all([self.src_token2idx, self.src_idx2token, self.tgt_token2idx, self.tgt_idx2token]):
print("字典已经构建过了")
return
# 2、如果缓存里面有都通过joblib保存的字典文件,则也不需要浪费资源重复构建,直接从字典文件中获取字典,再跳出这个构建字典的_build_dict方法即可
elif os.path.exists(self.saved_dict):
print("从缓存中读取字典")
self.load()
print("读取缓存字典成功")
return
# 3、如果上面两个条件都不满足,则开始从零构建字典
# 3.1 构建标记元素集
# <UNK>:未知标记,用于表示在词汇表中未出现的词,当模型遇到一个它在训练数据中未曾见过的词时,会用 <UNK> 来代替
# <PAD>:填充标记,用于将所有输入序列填充到相同的长度,以便于能够使用批处理和固定大小的神经网络输入
# <SOS>:序列开始标记,用于指示序列生成的开始
# <EOS>:序列结束标记,用于指示序列生成的结束
# 输入侧不需要"<SOS>"和"<EOS>",输出侧需要"<SOS>"和"<EOS>"
src_tokens = {"<UNK>", "<PAD>"}
tgt_tokens = {"<UNK>", "<PAD>", "<SOS>", "<EOS>"}
# 3.2 从语料文件中读取数据
with open(file=self.data_file, mode="r", encoding="utf8") as f:
# 读取每一行内容
for line in tqdm(f.readlines()):
# 如果内容不为空,则执行下面代码
if line:
# 并将英文句子和中文句子通过中间的制表符分开
src_sentence, tgt_sentence = line.strip().split("\t")
# 分别调用split_src和split_tgt方法,对src_sentence和tgt_sentence进行 【句子-->词】 的切分处理
src_sentence_tokens = self.split_src(src_sentence)
tgt_sentence_tokens = self.split_tgt(tgt_sentence)
# 通过union方法取并集,来得到src_tokens和tgt_tokens
src_tokens = src_tokens.union(set(src_sentence_tokens))
tgt_tokens = tgt_tokens.union(set(tgt_sentence_tokens))
# 3.3 构建src的字典,包括src_token2idx和src_idx2token,并获取对应的字典长度
self.src_token2idx = {token: idx for idx, token in enumerate(src_tokens)}
self.src_idx2token = {idx: token for token, idx in self.src_token2idx.items()}
self.src_dict_len = len(self.src_token2idx)
# 3.4 构建tgt的字典,包括tgt_token2idx和tgt_idx2token,并获取对应的字典长度
self.tgt_token2idx = {token: idx for idx, token in enumerate(tgt_tokens)}
self.tgt_idx2token = {idx: token for token, idx in self.tgt_token2idx.items()}
self.tgt_dict_len = len(self.tgt_token2idx)
# 3.5 将上面构建好的四个字典都存放至缓存文件中,方便后面读取
self.save()
print("保存字典成功")
def split_src(self, sentence):
"""
英文句子分词
"""
# 1、去除首尾空格
sentence = sentence.strip()
# 2、将句子进行进行 转小写-->分词-->去除空词和'-->存列表 的操作
tokens = [token for token in jieba.lcut(sentence.lower()) if token not in ("", " ", "'")]
# 3、返回处理后的列表
return tokens
def split_tgt(self, sentence):
"""
中文句子分词
"""
# 1、实例化opencc工具,并设置繁体转简体模式
# t2s,即:Traditional Chinese to Simplified Chinese,表示将繁体句子转换为简体句子
# s2t,即:Simplified Chinese to Traditional Chinese,表示将简体句子转换为繁体句子
converter = opencc.OpenCC(config="t2s")
# 2、进行句子转换
sentence = converter.convert(text=sentence)
# 3、将句子进行进行 分词-->去除空词-->存列表 的操作
tokens = [token for token in jieba.lcut(sentence) if token not in ["", " "]]
# 4、返回处理后的列表
return tokens
def encode_src(self, src_sentence, src_sentence_len):
"""
对src进行编码:把英文句子分词后变成 id
1、按本批次的最大长度来填充,没有见过的词置为"<UNK>",长度不够的用多个"<PAD>"填充
2、src不用加"<SOS>和"<EOS>"
"""
# 1、"<UNK>"转换
src_idx = [self.src_token2idx.get(token, self.src_token2idx.get("<UNK>")) for token in src_sentence]
# 2、"<PAD>"填充
src_idx = (src_idx + [self.src_token2idx.get("<PAD>")] * src_sentence_len)[:src_sentence_len]
# 3、返回处理后的id值
return src_idx
def encode_tgt(self, tgt_sentence, tgt_sentence_len):
"""
对tgt进行编码:把中文句子分词后变成 id
1、按本批次的最大长度来填充,没有见过的词置为"<UNK>",长度不够的用多个"<PAD>"填充
2、tgt需要在首尾分别加"<SOS>和"<EOS>"
"""
# 1、在首尾分别加"<SOS>和"<EOS>"
tgt_sentence = ["<SOS>"] + tgt_sentence + ["<EOS>"]
# 2、因为在首尾分别加了"<SOS>和"<EOS>",共两个词,所以最大长度要加2
tgt_sentence_len += 2
# 3、"<UNK>"转换
tgt_idx = [self.tgt_token2idx.get(token, self.tgt_token2idx.get("<UNK>")) for token in tgt_sentence]
# 4、"<PAD>"填充
tgt_idx = (tgt_idx + [self.tgt_token2idx.get("<PAD>")] * tgt_sentence_len)[:tgt_sentence_len]
# 5、返回处理后的id值
return tgt_idx
def decode_tgt(self, tgt_ids):
"""
对src进行解码:把生成的id序列转换成中文token
输入:[6360, 7925, 8187, 7618, 1653, 4509]
输出:['我', '爱', '北京', '<UNK>']
"""
preds = []
# [batch_size, seq_len]
for temp_tgt in tgt_ids:
# 到EOS则代表结束,需要退出
if temp_tgt == self.tgt_token2idx.get("<EOS>"):
break
# 将不等于"<SOS>"、"<PAD>"的id获取出来,并输出id对应的token(中文词)
# 正如人们赚不到自己认知范围之外的钱一样,生成的序列也不会有找不到的id,所以不需要替换输出"<UNK>"
preds.append([self.tgt_idx2token.get(tgt_id) for tgt_id in temp_tgt if tgt_id not in (tokenizer.tgt_token2idx.get("<SOS>"),
tokenizer.tgt_token2idx.get("<PAD>"))])
return preds
# classmethod是一个装饰器,表示这个方法是类方法,可以通过类名直接调用,而不需要实例化对象(cls是这个方法的第一个参数,代表类本身)
# 主要用途:
# (1)将一个方法定义为类方法,可以清晰地表明它是一个与类相关联的操作,而不是与特定实例相关,这有助于代码的组织和可读性,使得其他开发者更容易理解这个方法的用途
# (2)如果一个方法需要在子类中被重写,那么使用类方法可以确保子类可以覆盖这个方法并提供特定的实现,同时仍然保持类方法的特性
@classmethod
def subsequent_mask(cls, size):
"""
生成屏蔽未来词的subsequent_mask矩阵:size=seq_len
"""
# 1、定义了一个形状为(1, size, size)的张量,size是序列的长度,这个张量用于后续创建对应的上三角矩阵
attn_shape = (1, size, size)
# 2、创建一个上三角矩阵,大小为attn_shape,其中对角线上的值为1,其余为0
# triu-->triangle upper(上三角矩阵),tril-->triangle lower(下三角矩阵)
# diagonal=1参数表示从对角线上方1个位置开始填充为1,对角线本身及以下的元素填充为0
# (如果是下三角矩阵,则diagonal=1表示从对角线下方1个位置开始填充为1,对角线本身及以上的元素填充为0)
subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
# 3、上面得到的是【上三角元素为1,对角线本身及以下的元素为0】的矩阵,现在将其与0值作比较,得到的是【上三角元素为False,对角线本身及以下的元素为True】的矩阵
# 比如:(由于False的遮挡效果,使得每一步只能看当前和之前的True)
# [[True, False, False],
# [True, True, False],
# [True, True, True]]
return subsequent_mask == 0
@classmethod
def make_std_mask(cls, tgt, pad):
"""
使用subsequent_mask屏蔽未来词
"""
# 1、去掉"<PAD>",并用unsqueeze改变形状
# unsqueeze(-2):让pad_mask的形状 [batch_size, seq_len] 变为 [batch_size, 1, seq_len]
pad_mask = (tgt != pad).unsqueeze(-2)
# 2、去掉未来词
# tgt.size(-1)是当前序列的长度,即seq_len;
# type_as(pad_mask.data)是将subsequent_mask转换为与pad_mask.data相同的数据类型,以确保两个张量可以进行按位与操作
tgt_mask = pad_mask & Tokenizer.subsequent_mask(tgt.size(-1)).type_as(pad_mask.data)
# 3、返回屏蔽未来词之后的矩阵
return tgt_mask
def save(self):
"""
定义保存字典的方法
"""
# 1、定义字典中的元素内容
state_dict = {
"src_token2idx": self.src_token2idx,
"src_idx2token": self.src_idx2token,
"src_dict_len": self.src_dict_len,
"tgt_token2idx": self.tgt_token2idx,
"tgt_idx2token": self.tgt_idx2token,
"tgt_dict_len": self.tgt_dict_len
}
# 2、保存文件到.cache目录下
if not os.path.exists(".cache"):
os.mkdir(path=".cache")
torch.save(obj=state_dict, f=self.saved_dict)
def load(self):
"""
加载字典
"""
if os.path.exists(path=self.saved_dict):
state_dict = torch.load(f=self.saved_dict, weights_only=True)
self.src_token2idx = state_dict.get("src_token2idx")
self.src_idx2token = state_dict.get("src_idx2token")
self.src_dict_len = state_dict.get("src_dict_len")
self.tgt_token2idx = state_dict.get("tgt_token2idx")
self.tgt_idx2token = state_dict.get("tgt_idx2token")
self.tgt_dict_len = state_dict.get("tgt_dict_len")
实例化分词器
tokenizer = Tokenizer(data_file="./data.txt", saved_dict="./.cache/dicts.bin")
print(tokenizer)
3.3 数据打包
class TransformerDatas