把手写笔记搬上来,以后就用博客进行学习记录了,加油!
所谓Attention机制:以RNN作为encoder学习输入序列为例,encoder上所有节点(词)的隐藏层输出相当于输入序列的背景变量(既作为key向量,也作为value向量),decoder上每个输出节点的隐藏状态则作为query向量。通过寻找query与各个key之间的权重关系,得出每个输出节点词与encoder上所有节点词之间的权重系数;进而在计算某个输出节点的输出值时,用该输出节点对应的权重系数对输入序列各节点的隐藏层输出(此时叫做value向量)加权求和,将加权后的和与该输出节点的初始输出值相拼接,再经logsoftmax,即可求出该输出节点处概率最高的词。
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import nltk
import jieba
import numpy as np
# In[2]:
# Load the parallel English/Chinese data: Chinese text is segmented with jieba, English text is tokenized with nltk.
def load_data(file):
    """Read a tab-separated English/Chinese parallel corpus and tokenize it.

    Each line of ``file`` is split on tabs; the first field is treated as the
    English sentence and the second as the Chinese sentence. English text is
    lower-cased and tokenized with ``nltk.word_tokenize``; Chinese text is
    segmented with ``jieba.cut``. Every token list is wrapped in 'BOS' and
    'EOS' markers.

    Lines with fewer than two tab-separated fields (for example blank lines)
    are skipped instead of raising ``IndexError``.

    Args:
        file: Path of the UTF-8 encoded corpus file.

    Returns:
        A tuple ``(en_data, cn_data)`` of two parallel lists; element ``i`` of
        each list is the token list of the i-th usable sentence pair.
    """
    en_data = []
    cn_data = []
    with open(file, 'r', encoding='utf8') as text:
        # Iterate the file object directly so the whole corpus is never held
        # in memory as a list of raw lines.
        for line in text:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or malformed line: there is no sentence pair to read.
                continue
            en_tokenize = nltk.word_tokenize(fields[0].lower())
            cn_tokenize = list(jieba.cut(fields[1]))
            en_data.append(['BOS'] + en_tokenize + ['EOS'])
            cn_data.append(['BOS'] + cn_tokenize + ['EOS'])
    return en_data, cn_data
# In[3]:
# Load and tokenize the train / test / dev splits. Each load_data call
# returns a pair of parallel token-list collections (English, Chinese).
# NOTE(review): these are absolute Windows paths; presumably they only exist
# on the author's machine -- adjust before running elsewhere.
train_en, train_cn = load_data('D:\\Seq2seq\\en-cn\\train.txt')
#print(train_en[0],train_cn[0])
test_en, test_cn = load_data('D:\\Seq2seq\\en-cn\\test.txt')
dev_en, dev_cn = load_data('D:\\Seq2seq\\en-cn\\dev.txt')
# In[168]:
#构建词汇表
from collections import Counter
# Reserved indices: 0 is the padding token, 1 is the unknown-word token.
UNK_idx = 1
PAD_idx = 0


def word_dict(sentences_list, maxwords=50000):
    """Build a token -> index vocabulary from tokenized sentences.

    The ``maxwords`` most frequent tokens receive indices starting at 2, in
    descending order of frequency. The reserved tokens 'UNK' and 'PAD' are
    then added with the indices ``UNK_idx`` and ``PAD_idx``.

    Literal 'UNK' / 'PAD' tokens occurring in the corpus are not counted:
    otherwise they would first take a frequency-ranked index and then be
    overwritten by the reserved index, leaving a gap so that the largest
    index would be >= the returned vocabulary size.

    Args:
        sentences_list: List of sentences, each a list of tokens.
        maxwords: Maximum number of corpus tokens to keep.

    Returns:
        A tuple ``(vocab, total_words)`` where ``vocab`` is the token -> index
        dict (reserved tokens included) and ``total_words`` is its size.
    """
    reserved = ('UNK', 'PAD')
    word_counter = Counter(
        w for s in sentences_list for w in s if w not in reserved
    )
    vocab = {
        word: index + 2
        for index, (word, _count) in enumerate(word_counter.most_common(maxwords))
    }
    vocab['UNK'] = UNK_idx
    vocab['PAD'] = PAD_idx
    return vocab, len(vocab)
# In[169]:
# Build one vocabulary per split and language; each word_dict call returns the
# token -> index dict together with its size.
# NOTE(review): test and dev get vocabularies of their own here, so the same
# token can map to different indices in different splits. A model trained on
# the train indices would presumably need test/dev encoded with
# train_en_dict / train_cn_dict instead -- confirm against the code that
# consumes these dicts.
train_en_dict, train_en_total = word_dict(train_en)
train_cn_dict, train_cn_total = word_dict(train_cn)
test_en_dict, test_en_total = word_dict(test_en)
test_cn_dict, test_cn_total = word_dict(test_cn)
dev_en_dict, dev_en_total = word_dict(dev_en)
dev_cn_dict, dev_cn_total = word_dict(dev_cn)
# In[170]:
#将中英文词汇序列转变为数字序列
def word_index(word_dict, sentences_list):
    """Convert tokenized sentences into index sequences, raw and padded.

    Args:
        word_dict: Mapping from token to integer index. A token missing from
            the mapping is encoded with the index stored under 'UNK'.
        sentences_list: List of sentences, each a list of tokens.

    Returns:
        A 4-tuple ``(length_sentences_list, length_pad_list, word_index_list,
        word_padindex_list)``:

        * the original length of every sentence,
        * the length of every sentence after padding (all equal to the
          longest sentence's length),
        * the index sequence of every original sentence,
        * the index sequence of every sentence padded with 'PAD' tokens.

        All four lists have one entry per input sentence, in input order.
        An empty ``sentences_list`` yields four empty lists.
    """
    def encode(sentence):
        # Out-of-vocabulary tokens fall back to the 'UNK' index.
        return [word_dict[w] if w in word_dict else word_dict['UNK']
                for w in sentence]

    length_sentences_list = [len(s) for s in sentences_list]
    # default=0 keeps an empty corpus from raising ValueError.
    max_length = max(length_sentences_list, default=0)
    # Pad every sentence. The longest ones simply receive zero 'PAD' tokens;
    # filtering them out would drop them from the padded outputs and break
    # the one-to-one alignment with sentences_list.
    sentences_pad_list = [
        s + ['PAD'] * (max_length - len(s)) for s in sentences_list
    ]
    length_pad_list = [len(s) for s in sentences_pad_list]
    word_index_list = [encode(s) for s in sentences_list]
    word_padindex_list = [encode(s) for s in sentences_pad_list]
    return length_sentences_list, length_pad_list, word_index_list, word_padindex_list
# In[171]:
# Encode every split. For each call the results are, in order: the original
# sentence lengths, the padded sentence lengths, the index sequences, and the
# padded index sequences.
# NOTE(review): each split is encoded with its own vocabulary, so indices are
# not comparable across train / test / dev -- confirm this is intended.
train_en_length,train_en_padlength,train_en_index,train_en_padindex=word_index(train_en_dict,train_en)
train_cn_length,train_cn_padlength,train_cn_index,train_cn_padindex=word_index(train_cn_dict,train_cn)
test_en_length,test_en_padlength,test_en_index,test_en_padindex=word_index(test_en_dict,test_en)
test_cn_length,test_cn_padlength,test_cn_index,test_cn_padindex=word_index(test_cn_dict,test_cn)
dev_en_length,dev_en_padlength,dev_en_index,dev_en_padindex=word_index(dev_en_dict,dev_en)
dev_cn_le