最近的工作任务是将谷歌NMT整套系统用自己的C代码库搭建好,因此需要对整个代码结构了解的非常清晰才可以,TF的不友好让我这段时间遇到了大大小小各种各样的问题,在工作临近结束时将一些对代码的理解记录下来,在本篇文章之后也会陆续更新其他代码,欢迎多多交流~~
本篇文章主要讲了nmt中的对于vocab的数据处理方法。其中主要包括了四个函数,分别为load_vocab、checkvocab、create_vocab和load_embed_txt。我们从代码的开头进行分析。
开头
# word level special token
UNK = "<unk>"
SOS = "<s>"
EOS = "</s>"
UNK_ID = 0
# char ids 0-255 come from utf-8 encoding bytes
# assign 256-300 to special chars
BOS_CHAR_ID = 256 # <begin sentence>
EOS_CHAR_ID = 257 # <end sentence>
BOW_CHAR_ID = 258 # <begin word>
EOW_CHAR_ID = 259 # <end word>
PAD_CHAR_ID = 260 # <padding>
DEFAULT_CHAR_MAXLEN = 50 # max number of chars for each word.
本段代码主要是声明一个unk、sos、eos的token以及unk的id,这些东西都是seq2seq模型中必不可少的。
声明了特殊字符如sow、sos一类的编码id。
_string_to_bytes
该函数主要作用为输入字符串和最大长度,将字符串序列转换为用对应id来表示的tensor。我自己写了一个队该函数的简单的测试代码如下:
def _string_to_bytes(text, max_length):
byte_ids = tf.to_int32(tf.decode_raw(text, tf.uint8))
byte_ids = byte_ids[:max_length - 2]
padding = tf.fill([max_length - tf.shape(byte_ids)[0] - 2], PAD_CHAR_ID)
byte_ids = tf.concat(
[[BOW_CHAR_ID], byte_ids, [EOW_CHAR_ID], padding], axis=0)
tf.logging.info(byte_ids)
byte_ids = tf.reshape(byte_ids, [max_length])
tf.logging.info(byte_ids.get_shape().as_list())
return byte_ids + 1
if __name__=='__main__':
text = 'I am scofyyy'
maxlength = 50
sess = tf.Session()
print(sess.run(_string_to_bytes(text,maxlength)))
sess.close()
最终输出的结果为如下:
本函数作用就是将输入的字符串转换为用对应的id编码的tensor,最后的一长串的261即为之前定义的PAD_CHAR_ID,也就是说在句子长度小于50时会自动填充。但是值得注意的是返回值使用的是byte_ids+1,这是因为之前定义了UNK_ID=0,那么其他的id就要往后顺延一位,因此会加一。
load_vocab
该函数将vocab_file中的单词读出并写入一个列表,最后返回了该列表即列表长度(即单词数)。
def load_vocab(vocab_file):
vocab = []
with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
vocab_size = 0
for word in f:
vocab_size += 1
vocab.append(word.strip())#strip作用为去掉单词后的换行符
return vocab, vocab_size
在debug中内部如下图所示,因此需要去掉单词后的换行符。
check_vocab
该函数传入vocab_file和out_dir,作用为检测vocab_file中是否有三个特殊token,如果没有的话就依次加进去并放于最顶部并返回一个新的vocab_file。
create_vocab_tables
该函数作用为使用lookup_ops.index_table_from_file()函数为src_vocab_file和tgt_vocab_file创建单词索引表,即将单词和数字id一一对应,并返回一个查找表,查找表使用方法如下:
load_embed_txt
众所周知在将单词输入RNN网络时需要进行embedding操作,该函数将预先准备好的embed_file中的内容导入到一个python字典中。
该字典的键为单词,值为单词对应的embedding向量。最终返回embedding字典以及embedding后的向量的维度。embed_file的内容如下:
每一行为单词及其对应的向量。
总的来说关于vocab_utils的内容还是很简单很基础的,其主要目的就是介绍了一些处理机器翻译输入数据的有效方法。