基于transformer的机器翻译实战

一、基本模块

1、位置嵌入(position embedding)

(1)为什么要引入位置嵌入?

        文本序列中的单词是有顺序的,一个单词在序列中所处的位置对我们理解其词义、上下文关系都十分重要,但是传统的词向量嵌入(word embedding)并不包含位置信息,所以专门引入位置向量。

(2)如何实现位置嵌入?

主要有两种方式:

  • 可学习位置嵌入:为每一个位置初始化一个位置嵌入向量,并且将位置嵌入向量作为模型参数,之后会训练过程中不断更新该向量。
  • 绝对位置嵌入:位置嵌入向量初始化之后就不再改变。一般基于三角函数式,又称Sinusoidal Position Encoding,公式如下: 

 PE_{k, 2i} = sin(\frac{k}{10000^{\frac{2i}{d_{model}}}}) \\ PE_{k, 2i+1} = cos(\frac{k}{10000^{\frac{2i+1}{d_{model}}}})

        

        分别通过sin和cos计算位置 k  的编码向量的第 2 i 和 2 i + 1个分量,d_{model}是位置向量的维度。

两种方式的比较:有论文实验显示,绝对位置嵌入和可学习位置嵌入最终的效果是类似的,但是可学习位置嵌入会引入额外的参数,增加训练开销,所以本项目使用基于三角函数式的绝对位置嵌入

2、掩码机制(mask)

(1)mask的作用是什么?什么情况下需要使用mask?

        作用是避免过拟合,如果不使用mask,会导致模型在训练时就能看到整个句子,从而导致训练准确度上升很快,但是验证准确度会先升后降。

        第一种情况是输入序列长度不一致,需要使用“pad”字符补全短序列,保证序列长度一致性。在计算注意力时,就会需要将“pad”字符掩去。

        第二种情况是为了保证训练效果,在训练时不能直接看到整个句子,而是只能看到当前所处位置及其之前位置的单词,所以可以使用三角型的mask矩阵。

(2)mask实现方式

         对于第一种情况,需要先确定在词表中“pad”的序号,不妨假设pad = 1,序列向量seq = [[1,2,3],[2,2,2],[1,0,0]],辅助矩阵p = \begin{bmatrix} 1 &1 &1 \\ 1 &1 & 1\\ 1&1 & 1 \end{bmatrix},这里的1是因为pad=1,然后比较seq和p,相等的位置置1,不相等的位置置0,得到mask矩阵:

mask = [[1,0,0],[0,0,0],[1,0,0]]

        本项目使用self-attention,会出现上述第二种情况。由于是self-attention,因此Q = K = V,假设:Q = K = V = \begin{bmatrix} s_1\\ s_2\\ s_3\\ s_4 \end{bmatrix}

        根据注意力计算公式,需要先计算QK^T:

        当我们遍历到第2个位置时,应该只能知道s_1s_2,而无法看到s_3s_4,所以理论上无法计算出s_2s_3^Ts_2s_4^T,因此要把这两个位置掩去,同理可以推出mask矩阵形式为:

mask = \begin{bmatrix} 1 & 0 & 0 & 0\\ 1 & 1&0 & 0\\ 1& 1& 1&0 \\ 1& 1 & 1&1 \end{bmatrix}

二、代码

1、目录架构
Machine_translation
    --data #存放数据集
        --eng-fra.txt #英语-法语数据集
    --save #保存模型参数
    --data_process.py #数据预处理
    --decoder.py #定义transformer解码器
    --encoder.py #定义transformer编码器
    --layer.py #定义transformer网络层
    --modules.py #实现位置嵌入、mask、词/索引转换等模块
    --optimizer.py #动态学习率
    --train.py #配置以及训练
    --transformer.py #搭建transformer模型
2、data_process.py:数据预处理

数据集下载:见文章顶部

数据标准化流程:转小写 ——> 转码 ——> 在标点符号前插入空格 ——> 剔除数字等非法字符 ——> 剔除多余空格

import unicodedata
import re
import pandas as pd
import torchtext
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split

class DataLoader:
    def __init__(self, data_iter):
        self.data_iter = data_iter
        self.length = len(data_iter)  # 一共有多少个batch?

    def __len__(self):
        return self.length

    def __iter__(self):
        # 注意,在此处调整text的shape为batch first
        for batch in self.data_iter:
            yield (torch.transpose(batch.src, 0, 1), torch.transpose(batch.targ, 0, 1))

# 将unicode字符串转化为ASCII码
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

# 标准化句子序列
def normalizeString(s):
    s = s.lower().strip() # 全部转小写
    s = unicodeToAscii(s)
    s = re.sub(r"([.!?])", r" \1", s)  # \1表示group(1)即第一个匹配到的 即匹配到'.'或者'!'或者'?'后,一律替换成'空格.'或者'空格!'或者'空格?'
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)  # 非字母以及非.!?的其他任何字符 一律被替换成空格
    s = re.sub(r'[\s]+', " ", s)  # 将出现的多个空格,都使用一个空格代替。例如:w='abc  aa   bb' 处理后:w='abc aa bb'
    return s

# 文件是英译法,我们实现的是法译英,所以进行了reverse,所以pair[1]是英语
def exchangepairs(pairs):
    # 过滤,并交换句子顺序,得到法英句子对(之前是英法句子对)
    return [[pair[1], pair[0]] for pair in pairs]

def get_dataset(pairs, src, targ):
    fields = [('src', src), ('targ', targ)]  # filed信息 fields dict[str, Field])
    examples = []  # list(Example)
    for fra, eng in tqdm(pairs): # 进度条
        # 创建Example时会调用field.preprocess方法
        examples.append(torchtext.legacy.data.Example.fromlist([fra, eng], fields))
    return examples, fields

def get_datapipe(opt, src, tar):
    data_df = pd.read_csv(opt.data_dir + 'eng-fra.txt',  # 数据格式:英语\t法语,注意我们的任务源语言是法语,目标语言是英语
    encoding='UTF-8', sep='\t', header=None,
    names=['eng', 'fra'], index_col=False)
    pairs = [[normalizeString(s) for s in line] for line in data_df.values]
    pairs = exchangepairs(pairs)
    train_pairs, val_pairs = train_test_split(pairs, test_size=0.2, random_state=1234)
    
    ds_train = torchtext.legacy.data.Dataset(*get_dataset(train_pairs, src, tar))
    ds_val = torchtext.legacy.data.Dataset(*get_dataset(val_pairs, src, tar))
    
    train_iter, val_iter = torchtext.legacy.data.Iterator.splits(
        (ds_train, ds_val),
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        batch_sizes=(opt.batch_size, opt.batch_size)
    )
 
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值