谷歌NMT源码阅读之vocab_utils

最近的工作任务是将谷歌NMT整套系统用自己的C代码库搭建好,因此需要对整个代码结构了解的非常清晰才可以,TF的不友好让我这段时间遇到了大大小小各种各样的问题,在工作临近结束时将一些对代码的理解记录下来,在本篇文章之后也会陆续更新其他代码,欢迎多多交流~~

本篇文章主要讲了nmt中的对于vocab的数据处理方法。其中主要包括了四个函数,分别为load_vocab、checkvocab、create_vocab和load_embed_txt。我们从代码的开头进行分析。

开头

# word level special token
UNK = "<unk>"
SOS = "<s>"
EOS = "</s>"
UNK_ID = 0

# char ids 0-255 come from utf-8 encoding bytes
# assign 256-300 to special chars
BOS_CHAR_ID = 256  # <begin sentence>
EOS_CHAR_ID = 257  # <end sentence>
BOW_CHAR_ID = 258  # <begin word>
EOW_CHAR_ID = 259  # <end word>
PAD_CHAR_ID = 260  # <padding>

DEFAULT_CHAR_MAXLEN = 50  # max number of chars for each word.

本段代码主要是声明一个unk、sos、eos的token以及unk的id,这些东西都是seq2seq模型中必不可少的。
声明了特殊字符如sow、sos一类的编码id。

_string_to_bytes

该函数主要作用为输入字符串和最大长度,将字符串序列转换为用对应id来表示的tensor。我自己写了一个队该函数的简单的测试代码如下:

def _string_to_bytes(text, max_length):

  byte_ids = tf.to_int32(tf.decode_raw(text, tf.uint8))
  byte_ids = byte_ids[:max_length - 2]
  padding = tf.fill([max_length - tf.shape(byte_ids)[0] - 2], PAD_CHAR_ID)
  byte_ids = tf.concat(
      [[BOW_CHAR_ID], byte_ids, [EOW_CHAR_ID], padding], axis=0)
  tf.logging.info(byte_ids)

  byte_ids = tf.reshape(byte_ids, [max_length])
  tf.logging.info(byte_ids.get_shape().as_list())
  return byte_ids + 1

if __name__=='__main__':
	text = 'I am scofyyy'
	maxlength = 50
	sess = tf.Session()
	print(sess.run(_string_to_bytes(text,maxlength)))
	sess.close()

最终输出的结果为如下:
长度本函数作用就是将输入的字符串转换为用对应的id编码的tensor,最后的一长串的261即为之前定义的PAD_CHAR_ID,也就是说在句子长度小于50时会自动填充。但是值得注意的是返回值使用的是byte_ids+1,这是因为之前定义了UNK_ID=0,那么其他的id就要往后顺延一位,因此会加一。

load_vocab

该函数将vocab_file中的单词读出并写入一个列表,最后返回了该列表即列表长度(即单词数)。


def load_vocab(vocab_file):
  vocab = []
  with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
    vocab_size = 0
    for word in f:
      vocab_size += 1
      vocab.append(word.strip())#strip作用为去掉单词后的换行符
  return vocab, vocab_size

在debug中内部如下图所示,因此需要去掉单词后的换行符。
在这里插入图片描述

check_vocab

该函数传入vocab_file和out_dir,作用为检测vocab_file中是否有三个特殊token,如果没有的话就依次加进去并放于最顶部并返回一个新的vocab_file。

create_vocab_tables

该函数作用为使用lookup_ops.index_table_from_file()函数为src_vocab_file和tgt_vocab_file创建单词索引表,即将单词和数字id一一对应,并返回一个查找表,查找表使用方法如下:
在这里插入图片描述

load_embed_txt

众所周知在将单词输入RNN网络时需要进行embedding操作,该函数将预先准备好的embed_file中的内容导入到一个python字典中。
该字典的键为单词,值为单词对应的embedding向量。最终返回embedding字典以及embedding后的向量的维度。embed_file的内容如下:
在这里插入图片描述
每一行为单词及其对应的向量。

总的来说关于vocab_utils的内容还是很简单很基础的,其主要目的就是介绍了一些处理机器翻译输入数据的有效方法。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值