目录
前言
🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
🍖 原作者:[K同学啊]
说在前面:理解Seq2Seq的代码,并跑通
一、前期准备工作
1.1 导入所需库
代码如下:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
1.2 搭建语言类
定义了两个常量 SOS_token 和 EOS_token,其分别代表序列的开始和结束。 Lang 类,用于方便对语料库进行操作:
- word2index 是一个字典,将单词映射到索引。
- word2count 是一个字典,记录单词出现的次数。
- index2word 是一个字典,将索引映射到单词。
- n_words 是单词的数量,初始值为 2,因为序列开始和结束的单词已经被添加
addSentence 方法:用于向 Lang 类中添加一个句子,它会调用 addWord 方法将句子中的每个单词添加到 Lang 类中。
addWord 方法:将单词添加到 word2index、word2count 和 index2word 字典中,并对 n_words 进行更新。如果单词已经存在于 word2index 中,则将 word2count 中对应的计数器加 1。
代码如下:
SOS_token = 0 #序列的开始
EOS_token = 1 #序列的结束
# 语言类,方便对语料库进行操作
class Lang:
def __init__(self, name):
self.name = name
self.word2index = {} #将单词映射到索引
self.word2count = {} #记录单词出现的次数
self.index2word = {0: "SOS", 1: "EOS"} #将索引映射到单词
self.n_words= 2 #单词的数量,初始值为2 Count SOS and EOS
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
1.3 文本处理函数
unicodeToAscii 函数:
- 使用了 Python 的 unicodedata 模块,通过 normalize 方法将字符串 s 转换为 Unicode 规范化形式 NFD。
- 使用条件判断语句过滤掉了 unicodedata.category© 为 ‘Mn’ 的字符。
- 剩下的字符通过join组成了一个新的字符串。
“Mn”(即“Nonspacing_Mark”)是表示“非间隔标记”的字符类别之一,“非间隔标记”是指那些不会独立显示的标记或符号,它们通常附加在其他字符上面以改变该字符的发音或外观。例如,重音符号(如“é”中的“´”)和分音符号(如“ā”中的“ˉ”)就属于“非间隔标记”
normalizeString 函数:
- 将字符串 s 转换为小写字母形式,并去除首尾空格,随后将字符串输入unicodeToAscii 函数。
- 通过正则表达式替换,将句子中的标点符号(‘.’、‘!’、‘?’)前添加一个空格。
- 将非字母符号替换为空格。
- 最后返回处理后的字符串 s
代码如下:
# 1.3 文本处理函数
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)
# 小写化,剔除标点与非字母符号
def normalizeString(s):
s = unicodeToAscii(s.lower().strip())
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s