import tensorflow as tf
'''
In an LSTM, every cell carries a state made up of a cell state (c_state) and a
hidden state (hidden_state); the output is the hidden_state of the cell at the
last time step (i.e. the last word).
'''
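# A minimal sketch (illustrative only, not part of the model below): with
# return_state=True a Keras LSTM returns (output, hidden_state, cell_state),
# and with return_sequences=False the output equals the last step's hidden_state.
demo_lstm = tf.keras.layers.LSTM(4, return_state=True)
demo_inputs = tf.random.normal([2, 5, 8])            # (batch, length, features)
demo_output, demo_h, demo_c = demo_lstm(demo_inputs)
# demo_output.shape == demo_h.shape == (2, 4); demo_c.shape == (2, 4)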
embedding_units = 256
units = 1024
input_vocab_size = len(input_tokenizer.word_index) +1
output_vocab_size = len(output_tokenizer.word_index) +1
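# input_tokenizer / output_tokenizer are assumed to be Keras Tokenizers fit on
# the source and target corpora in an earlier (omitted) preprocessing step, e.g.:
#   input_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
#   input_tokenizer.fit_on_texts(input_texts)   # input_texts: list of source sentences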
# Encoder
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_units, encoding_units, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.encoding_units = encoding_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_units)
        self.gru = tf.keras.layers.GRU(self.encoding_units, return_sequences=True,
                                       return_state=True, recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        # look up embeddings for the input token ids
        x = self.embedding(x)
        # run the GRU: output holds the per-step hidden states, state is the final hidden state
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros([self.batch_size, self.encoding_units])
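# Quick shape check (a sketch, not from the original post): batch_size = 64 and
# the random integer batch below stand in for the real, preprocessed dataset.
batch_size = 64
encoder = Encoder(input_vocab_size, embedding_units, units, batch_size)
sample_hidden = encoder.initialize_hidden_state()
sample_input = tf.random.uniform([batch_size, 16], maxval=input_vocab_size, dtype=tf.int32)
sample_output, sample_hidden = encoder(sample_input, sample_hidden)
# sample_output.shape -> (64, 16, 1024); sample_hidden.shape -> (64, 1024)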
# Bahdanau (additive) attention
class BahdanauAttention(tf.keras.Model):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, decoder_hidden, encoder_outputs):
        # decoder_hidden.shape:   (batch_size, units)
        # encoder_outputs.shape:  (batch_size, length, units)
        # add a time axis so the decoder hidden state broadcasts over the sequence
        decoder_hidden_with_time_axis = tf.expand_dims(decoder_hidden, 1)
        # before V: (batch_size, length, units)
        # after V:  (batch_size, length, 1)
        score = self.V(tf.nn.tanh(
            self.W1(encoder_outputs) + self.W2(decoder_hidden_with_time_axis)))
        # shape: (batch_size, length, 1), normalized over the length axis
        attention_weights = tf.nn.softmax(score, axis=1)
        # shape: (batch_size, length, units)
        context_vector = attention_weights * encoder_outputs
        # shape: (batch_size, units)
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
# Sanity-check the attention layer on the encoder outputs
attention_model = BahdanauAttention(units=10)
attention_result, attention_weights = attention_model(sample_hidden, sample_output)
# attention_result.shape  -> (batch_size, encoder units)
# attention_weights.shape -> (batch_size, length, 1)