from tensorflow.keras.layers import Activation, Concatenate, Dense, Dot, Lambda, Layer
class Attention(Layer):

    def __init__(self, units=128, **kwargs):
        self.units = units
        super().__init__(**kwargs)

    def __call__(self, inputs):
        """
        Many-to-one attention mechanism for Keras.
        @param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
        @return: 2D tensor with shape (batch_size, units).
        @author: felixhao28, philipperemy.
        """
        hidden_states = inputs
        hidden_size = int(hidden_states.shape[2])
        # Project every hidden state so it can be scored against the last hidden state.
        score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states)
        # h_t: the last hidden state, shape (batch_size, hidden_size).
        h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states)
        # Alignment scores between h_t and every time step, shape (batch_size, time_steps).
        score = Dot(axes=[1, 2], name='attention_score')([h_t, score_first_part])
        attention_weights = Activation('softmax', name='attention_weight')(score)
        # Context vector: attention-weighted sum of the hidden states, shape (batch_size, hidden_size).
        context_vector = Dot(axes=[1, 1], name='context_vector')([hidden_states, attention_weights])
        pre_activation = Concatenate(name='attention_output')([context_vector, h_t])
        # Final attention vector, shape (batch_size, units).
        attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')(pre_activation)
        return attention_vector

    def get_config(self):
        return {'units': self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
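

# A minimal usage sketch, not part of the original snippet: the input dimensions,
# layer sizes, and variable names below are assumptions chosen for illustration.
# Because Attention consumes the full sequence of hidden states, the recurrent
# layer feeding it must be created with return_sequences=True.
import numpy as np
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model

time_steps, input_dim = 20, 8                                  # assumed example shapes
model_input = Input(shape=(time_steps, input_dim))
hidden_states = LSTM(64, return_sequences=True)(model_input)   # (batch, 20, 64)
attention_output = Attention(units=32)(hidden_states)          # (batch, 32)
model_output = Dense(1)(attention_output)
model = Model(model_input, model_output)
model.compile(optimizer='adam', loss='mse')

# Fit on random data just to confirm the graph is wired correctly.
x = np.random.uniform(size=(16, time_steps, input_dim))
y = np.random.uniform(size=(16, 1))
model.fit(x, y, epochs=1, verbose=0)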