以下为原博客的部分摘录
1、概述
- AttentionMechanism: 所有attention机制的父类, 内部没有任何实现。
- _BaseAttentionMechanism: 继承自AttentionMechanism, 定义了attention机制的一些公共方法实现和属性。
- BahdanauAttention和LuongAttention:均继承自_BaseAttentionMechanism,分别实现了1.2节所述的两种attention机制。
- AttentionWrapper: 用于封装RNNCell的类,继承自RNNCell,所以被它封装后依然是一个RNNCell类,只不过是带了attention的功能。
- AttentionWrapperState:用来存放计算过程中的state,前面说了AttentionWrapper其实也是一个RNNCell,那么它也有隐藏态(hidden state)信息,AttentionWrapperState就是这个state。除了RNN cell state,其中还额外存放了一些信息。
代码中一些命名术语的意思:
- key & query:Attention的本质可以被描述为一个查询(query)与一系列(键key-值value)对一起映射成一个输出:将query和每个key进行相似度计算得到权重并进行归一化,将权重和相应的键值value进行加权求和得到最后的attention,这里key=value。简单理解就是,query相当于前面说的解码器的隐藏态 h′i ,而key就是编码器的隐藏态 hi。
- memory: 这个memory其实才是编码器的所有隐藏态,与前面的key区别就是key可能是memory经过处理(例如线性变换)后得到的。
- alignments: 计算得到的每步编码器隐藏态 h 的权重向量,即 [αi1,αi2,…,αiTx]。
2、源码剖析
2.1 _BaseAttentionMechanism
初始化方法如下所示:
def __init__(self,
query_layer,
memory,
probability_fn,
memory_sequence_length=None,
memory_layer=None,
check_inner_dims_defined=True,
score_mask_value=None,
name=None):
以下是参数的说明:
- query_layer: 一个tf.layers.Layer实例, query会首先经过这一层.
- memory: 解码时用到的所有上下文信息,可简单理解为编码器的所有隐藏态.
- probability_fn: 将score eij 计算成概率用的函数,默认使用softmax,还可以指定hardmax等函数.
- memory_sequence_length: 即memory变量的实际长度信息,类似dynamic_rnn中的sequence_length,维度为[batch],这会被用作mask来去除超过实际长度的无用信息;
- memory_layer: 类似query_layer,也是一个tf.layers.Layer实例(或None),memory会经过这一层然后得到keys。需要注意的是,(经过memory_layer处理后得到的)key应该和(经过query_layer处理得到的)query的维度相匹配.
- check_inner_dims_defined: bool型,是否检查memory除了最外面两维其他维度是否是有定义的。
- score_mask_value: 在使用probability_fn计算概率之前,对score预先进行mask使用的值,默认是负无穷。但这个只有在memory_sequence_length参数定义的时候有效。
2.2 BahdanauAttention
BahdanauAttention是继承自2.2节的_BaseAttentionMechanism,它的构造函数如下所示:
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
score_mask_value=None,
dtype=None,
name="BahdanauAttention"):
以下是参数的声明
- num_units: 在进行query和key计算时, 两者的维度可能并不是统一的,所以需要进行变换和统一,因此用了num_units来声明两个全连接Dense网络,用于统一二者的维度:特别注意: BahdanauAttention的num_units和LuongAttention是有区别的.
query_layer=layers_core.Dense(num_units, name="query_layer", use_bias=False, dtype=dtype)
memory_layer=layers_core.Dense(num_units, name="memory_layer", use_bias=False, dtype=dtype)
- normalize: 是否在计算分数score时实现标准化.
2.3 LuongAttention
LuongAttention同样是继承自2.2节的_BaseAttentionMechanism,它的构造函数如下所示:
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
scale=False,
probability_fn=None,
score_mask_value=None,
dtype=None,
name="LuongAttention"):
- scale: 代表是否对得到的分数进行scale操作.
- num_units: 特别注意:这里和BahdanauAttention不一致
memory_layer = layers_core.Dense(num_units, name="memory_layer")
这里只对memory做了变换,对query并没有,所以num_units必须等于query的深度.
3、附录
3.1. Bahdanau attention & Luong Attention原理
详见https://blog.youkuaiyun.com/sinat_34072381/article/details/106728056
3.2. 参考
源码详细参考博客https://tangshusen.me/2019/03/09/tf-attention/
原理详细参考博https://blog.youkuaiyun.com/sinat_34072381/article/details/106728056