基本原理
Transformer是一种深度学习模型,它通过自注意力机制和位置编码,实现了对序列数据的处理。在机器翻译任务中,Transformer模型将输入的源语言文本序列作为输入,通过编码器和解码器两个阶段,生成目标语言的文本序列。
在编码阶段,Transformer模型将源语言文本序列中的每个单词都映射到一个向量表示,并通过自注意力机制计算出每个单词的权重。然后,通过位置编码将单词的位置信息编码为向量,与单词向量相加得到最终的表示。在解码阶段,模型将编码器输出的向量作为输入,通过解码器生成目标语言的文本序列。
这个系统可以分为以下几个关键部分:
1.词汇表(Vocab):
Vocab类用于创建词汇表,包含每个单词或子词的频率信息,并添加特殊标记,如(未知单词)、(填充标记)、(句子开始标记)和(句子结束标记)。
2.分词器(Tokenizer):
Tokenizer类用于将句子转换为单词或子词的标识符序列。
3.模型(Seq2SeqTransformer):
- Seq2SeqTransformer类继承自nn.Module,它定义了一个序列到序列的Transformer模型,包括编码器和解码器。
- 模型包含词嵌入层、位置编码、Transformer编码器层、Transformer解码器层和输出线性层。
- forward方法实现了模型的前向传播,包括编码器和解码器的处理。
- encode和decode方法分别实现了编码器和解码器的前向传播。
4.损失函数(Loss Function):
loss_fn函数用于计算模型的损失。
5.数据处理(Data Processing):
- data_process函数将文本数据转换为张量形式。
- generate_batch函数将一批数据转换为批次形式,并添加开始和结束标记。
- create_mask函数创建源语言和目标语言的掩码,包括注意力掩码和填充掩码。
6.训练和评估(Training and Evaluation):
- train_epoch函数用于在一个epoch内训练模型。
- evaluate函数用于评估模型性能。
7.翻译(Translation):
- greedy_decode函数使用贪婪解码策略生成翻译结果。
- translate函数将源语言句子翻译成目标语言。
做完前期了解工作 我们开始编码吧!
数据准备
JParaCrawlhttp://www.kecl.ntt.co.jp/icl/lirg/jparacrawl在这个链接中下载汉-日对照数据集
然后把下载下来的数据转换为pandas的数据结果Dataframe便于处理
df = pd.read_csv('./zh-ja/zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)
trainen = df[2].values.tolist()#[:10000]
trainja = df[3].values.tolist()#[:10000]
我们可以通过下面的抽样检查 判断Dataframe中是否有误
print(trainen[500])
print(trainja[500])
输出显示是无误的
因为日语不存在空格提供天然的分词,所以我们需要对日语句子进行分词操作
ja_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.ja.nopretok.model')
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')
接下来我们需要构建词汇表
并且把句子转换为张量
- 使用分词器编码句子: 使用之前加载的 SentencePiece 分词器将日语句子转换为子词的标识符序列。
- 统计词频: 使用 Counter 对象统计每个子词的出现次数,这有助于构建词汇表。
- 构建词汇表: 使用 TorchText 的 Vocab 类,根据子词的频率信息创建词汇表对象。
- 特殊标记: 词汇表会自动添加一些特殊标记,例如:
<unk>
:表示未知单词<pad>
:用于填充序列到相同长度<bos>
:表示句子的开始<eos>
:表示句子的结束
- 转换句子为张量: 使用词汇表将子词的标识符转换为整数索引,并将索引转换为 PyTorch 张量。 这样,模型就可以处理这些张量数据了。
# 定义一个构建词汇表的函数
def build_vocab(sentences, tokenizer):
"""
根据给定的句子列表和分词器构建词汇表。
参数:
sentences -- 列表,包含多个句子字符串。
tokenizer -- 一个分词器,能够将句子编码为单词或子词的标识符。
返回:
Vocab -- 一个词汇表对象,它包含句子中所有单词的频率信息,以及特殊标记。
这个函数首先创建一个Counter对象来计算每个单词或子词的出现次数。
然后,它遍历提供的句子列表,使用分词器对每个句子进行编码,并更新Counter。
最后,它使用Counter创建一个Vocab对象,这个对象还会添加几个特殊标记,
例如 '<unk>' 表示未知单词,'<pad>' 表示填充,'<bos>' 表示句子的开始,'<eos>' 表示句子的结束。
"""
counter = Counter()
for sentence in sentences: