第25周:seq2seq翻译实战-Pytorch复现(结合注意力机制)

目录

前言

一、前期准备工作

1.1 导入所需库

1.2 搭建语言类

1.3 文本处理函数

1.4 文件读取函数

二、Seq2Seq模型

2.1 编码器(Encoder)

2.2 解码器(Decoder)

三、训练

3.1 数据预处理

3.2 训练函数

四、训练与评估

五、可视化训练

5.1 LOSS图

5.2 可视化注意力

总结


前言

说在前面:理解Seq2Seq的代码,并跑通

加上如下图所示的注意力机制


一、前期准备工作

1.1 导入所需库

代码如下:

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

1.2 搭建语言类

定义了两个常量 SOS_token 和 EOS_token,其分别代表序列的开始和结束。 Lang 类,用于方便对语料库进行操作:

  • word2index 是一个字典,将单词映射到索引。
  • word2count 是一个字典,记录单词出现的次数。
  • index2word 是一个字典,将索引映射到单词。
  • n_words 是单词的数量,初始值为 2,因为序列开始和结束的单词已经被添加

addSentence 方法:用于向 Lang 类中添加一个句子,它会调用 addWord 方法将句子中的每个单词添加到 Lang 类中。
addWord 方法:将单词添加到 word2index、word2count 和 index2word 字典中,并对 n_words 进行更新。如果单词已经存在于 word2index 中,则将 word2count 中对应的计数器加 1。

代码如下:

SOS_token = 0     #序列的开始
EOS_token = 1     #序列的结束

# 语言类,方便对语料库进行操作
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}                       #将单词映射到索引
        self.word2count = {}                       #记录单词出现的次数
        self.index2word = {0: "SOS", 1: "EOS"}     #将索引映射到单词
        self.n_words= 2                            #单词的数量,初始值为2 Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

1.3 文本处理函数

unicodeToAscii 函数:

  • 使用了 Python 的 unicodedata 模块,通过 normalize 方法将字符串 s 转换为 Unicode 规范化形式 NFD。
  • 使用条件判断语句过滤掉了 unicodedata.category© 为 ‘Mn’ 的字符。
  • 剩下的字符通过join组成了一个新的字符串。

“Mn”(即“Nonspacing_Mark”)是表示“非间隔标记”的字符类别之一,“非间隔标记”是指那些不会独立显示的标记或符号,它们通常附加在其他字符上面以改变该字符的发音或外观。例如,重音符号(如“é”中的“´”)和分音符号(如“ā”中的“ˉ”)就属于“非间隔标记”

normalizeString 函数:

  • 将字符串 s 转换为小写字母形式,并去除首尾空格,随后将字符串输入unicodeToAscii 函数。
  • 通过正则表达式替换,将句子中的标点符号(‘.’、‘!’、‘?’)前添加一个空格。
  • 将非字母符号替换为空格。
  • 最后返回处理后的字符串 s

代码如下:

# 1.3 文本处理函数
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# 小写化,剔除标点与非字母符号
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

1.4 文件读取函数

       接受两个参数 lang1 和 lang2,分别表示要读取的语言。
       函数使用 Python 的 open 函数读取指定的文件,文件名格式为lang1-lang2.txt,以行为单位读取文件内容,并使用 strip 方法去掉每行末尾的换行符。接着,使用 split 方法将文本按照换行符分割成一个字符串列表 lines。
       对于列表 lines 中的每一行,使用 split 方法将其按照制表符分割成两个元素,分别表示 A 语言文本和 B 语言文本。对于每个元素,调用 normalizeString 函数进行预处理,并将处理后的 A 语言文本和 B 语言文本组成一个新的列表 pairs。
       参数 reverse 的值,创建 input_lang 和 output_lang 两个 Lang 类的实例,分别表示输入语言和输出语言。如果 reverse 为 True,则将 pairs 列表中的每个元素反转,并将 input_lang 和 output_lang 交换。最后,返回 input_lang、output_lang 和 pairs 三个值。
       其实举个例子可以方便理解,比如文件为eng-fra.txt,对应的lang1:eng、lang2:fra。我们按照行读取数据,随便抽一行:I see. Je comprends.,中间使用制表符 ‘\t’ 分割,读取会将这一行放入列表的同一行,随后使用normalizeString 函数进行处理,将处理后的I see.和Je comprends.组成一个新的列表 pairs。如果 reverse 为 False,则 input_lang 对象对应 lang1 表示的源语言,output_lang 对象对应 lang2 表示的目标语言。

代码如下:

#1.4 文件读取函数

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")
    # 以行为
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值