一、Wrapper装饰器
1. Wrapper基本结构
以tf.nn.rnn_cell.DropoutWrapper为例:
@tf_export("nn.rnn_cell.DropoutWrapper") # 起了个名字.
class DropoutWrapper(RNNCell): # 这里看出继承了RNNCell
def __init__(self, cell, ....):
"""
仅涉及变量的初始化和assert.
"""
def __call__(self, inputs,state, scope=None): # 真正的功能
"""
功能实现.
"""
return output, new_state
def zero_state(self, batch_size, dtype): # 这个是继承自RNNCell的
"""
"""
@property # 以属性方式, 调用函数.
def state_size(self):
return self._cell.state_size
@property
def output_size(self):
return self._cell.output_size
主要就是以上几点,这是因为继承自RNNCell, 所以一些不符合自身的方法需要重写.
以下是RNNCell的源码结构
@tf_export("nn.rnn_cell.RNNCell")
class RNNCell(base_layer.Layer):
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
"""
变量的初始化, 以及调用父类方法.
"""
def __call__(self, inputs, state, scope=None):
"""
实现方法
"""
@property
def state_size(self):
raise NotImplementedError("Abstract method") # 抽象方法, 必须实现.
@property
def output_size(self):
raise NotImplementedError("Abstract method") # 抽象方法, 必须实现.
def get_initial_state(self, ...):
"""
"""
def zero_state(self, bath_size, dtype):
"""
"""
除了源码中必须实现的抽象方法, 其他皆可不管,但也可以覆盖重写。
2. AttentionWrapper基本结构
为了保证AttentionWrapper封装后还是RNNCell的实例,做了以下处理:定义AttentionWrapperState, 代替原有State, 原先RNNCell的状态只作为它的一个域,而它还有别的域,用来保存和注意力机制相关的东西,这样,封装后的单元仍然有类似公式yt, st = f(xt, st-1)的循环,改变的只是state,但不影响和别的接口结合。
AttentionWrapper的源码结构:
class AttentionWrapper(rnn_cell_impl.RNNCell):
def __init__(self, cell, attention_mechanism, ...):
"""
初始化
"""
@property
def output_size(self):
if self._output_attention:
return self._attention_layer_size
else:
return self._cell.output_size
@property
def state_size(self):
return AttentionWrapperState(...)
def zero_state(self, batch_size, dtype):
"""
初始化方法
"""
return AttentionWrapperState(...)
def call(self, inputs, state):
"""
1. 通过 `cell_input_fn`函数组合 `inputs` 和 上一个time step的 `attention`
2. 使用input 和 previous state调用 cell函数
3. 使用attention_mechanism计算cell的output
4. 计算alignments
5. 使用alignments和memory的点积计算context vector
6. 通过attention layer(一层dense)拼接cell output和context.
"""
AttentionWrapperState的源码结构:
class AttentionWrapperState(collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time", "alignments", "alignment_history", "attention_state"))):
def clone(self, **kwargs):
"""
实现.
"""
二、dynamic_rnn的两个输出.
2.1 问题描述
由于输入文本有padding, 这样喂入rnn网络中,最后得到的hidden state是哪部分的呢?
是最后padding的,还是之前最后一个有效部分的?
结果有什么影响吗?
2.2 代码验证
import tensorflow as tf
if __name__ == "__main__":
X = tf.get_variable(shape=[1, 8, 4], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer(),
name="input")
seq_len = tf.constant([5])
basic_rnn = tf.nn.rnn_cell.GRUCell(4)
# 从这里可以看出.
# h 和 o 大多数时候等同.
outputs, hidden_state = tf.nn.dynamic_rnn(cell=basic_rnn, inputs=X,
sequence_length=seq_len, dtype=tf.float32,
swap_memory=True)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
o, h = sess.run([outputs, hidden_state])
print("o is :", o)
print("h is :", h)
# 结果:
o is : [[[ 0.01111094 -0.0157681 0.0695639 0.06153589]
[ 0.00111751 -0.15234253 0.10063156 0.05778597]
[-0.0316438 -0.0516398 0.17971487 0.07490367]
[-0.10333496 0.07261787 0.2896123 0.13996798]
[ 0.130892 0.00722314 0.1544136 -0.08568334]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
h is : [[ 0.130892 0.00722314 0.1544136 -0.08568334]]
结果可以看出,padding部分的output是0.
而最后的hidden_state和最后一个有效部分的output值相同.
因此可以暂时不用多做考虑.
参考
参考博客https://zhuanlan.zhihu.com/p/34393028
参考博客https://tangshusen.me/2019/03/09/tf-attention/ 高赞, 附源码解析.