Tensorflow RNN相关——高级

一、Wrapper装饰器

1. Wrapper基本结构

  以tf.nn.rnn_cell.DropoutWrapper为例:

@tf_export("nn.rnn_cell.DropoutWrapper")    	# 起了个名字. 
class DropoutWrapper(RNNCell):        			# 这里看出继承了RNNCell
	def __init__(self, cell, ....):
		"""
			仅涉及变量的初始化和assert. 
		""" 
		
	def __call__(self, inputs,state, scope=None):   # 真正的功能
		"""
			功能实现. 
		"""
		return output, new_state
	
	def zero_state(self, batch_size, dtype):    # 这个是继承自RNNCell的
		"""
		"""
		
	@property        # 以属性方式, 调用函数. 
	def state_size(self): 
		return self._cell.state_size 
	
	@property   
	def output_size(self):
		return self._cell.output_size

  主要就是以上几点,这是因为继承自RNNCell, 所以一些不符合自身的方法需要重写.

  以下是RNNCell的源码结构

@tf_export("nn.rnn_cell.RNNCell")
class RNNCell(base_layer.Layer):
	def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
		"""
			变量的初始化, 以及调用父类方法.
		"""
	
	def __call__(self, inputs, state, scope=None):
		"""
			实现方法
		"""
	
	@property
	def state_size(self):
		raise NotImplementedError("Abstract method")       # 抽象方法, 必须实现. 
	
	@property
	def output_size(self):
		raise NotImplementedError("Abstract method")       # 抽象方法, 必须实现. 
	
	def get_initial_state(self, ...):
		"""
		"""
	
	def zero_state(self, bath_size, dtype):
		"""
		
		"""

  除了源码中必须实现的抽象方法, 其他皆可不管,但也可以覆盖重写。

2. AttentionWrapper基本结构

  为了保证AttentionWrapper封装后还是RNNCell的实例,做了以下处理:定义AttentionWrapperState, 代替原有State, 原先RNNCell的状态只作为它的一个域,而它还有别的域,用来保存和注意力机制相关的东西,这样,封装后的单元仍然有类似公式yt, st = f(xt, st-1)的循环,改变的只是state,但不影响和别的接口结合。

AttentionWrapper的源码结构:
class AttentionWrapper(rnn_cell_impl.RNNCell):
	def __init__(self, cell, attention_mechanism, ...):
		"""
			初始化
		"""
	
	@property
	def output_size(self):
		if self._output_attention:
			return self._attention_layer_size
		else:
			return self._cell.output_size 
	
	@property
	def state_size(self):
		return AttentionWrapperState(...)
	
	def zero_state(self, batch_size, dtype):
		"""
			初始化方法
		"""
		return AttentionWrapperState(...)
	
	def call(self, inputs, state):
		"""
			1. 通过 `cell_input_fn`函数组合 `inputs` 和 上一个time step的 `attention`
			2. 使用input 和 previous state调用 cell函数 
			3. 使用attention_mechanism计算cell的output
			4. 计算alignments
			5. 使用alignments和memory的点积计算context vector
			6. 通过attention layer(一层dense)拼接cell output和context. 
		"""
AttentionWrapperState的源码结构:
class AttentionWrapperState(collections.namedtuple("AttentionWrapperState",
							("cell_state", "attention", "time", "alignments", "alignment_history", "attention_state"))):
	def clone(self, **kwargs):
		"""
			实现.
		"""

源码链接https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py

二、dynamic_rnn的两个输出.

2.1 问题描述

由于输入文本有padding, 这样喂入rnn网络中,最后得到的hidden state是哪部分的呢?
是最后padding的,还是之前最后一个有效部分的?
结果有什么影响吗?

2.2 代码验证

import tensorflow as tf

if __name__ == "__main__":
	X = tf.get_variable(shape=[1, 8, 4], dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer(),
                        name="input")
    seq_len = tf.constant([5])

    basic_rnn = tf.nn.rnn_cell.GRUCell(4)

    # 从这里可以看出.
    # h 和 o 大多数时候等同.
    outputs, hidden_state = tf.nn.dynamic_rnn(cell=basic_rnn, inputs=X,
                                              sequence_length=seq_len, dtype=tf.float32,
                                              swap_memory=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        o, h = sess.run([outputs, hidden_state])
        print("o is :", o)
        print("h is :", h)


# 结果: 
o is : [[[ 0.01111094 -0.0157681   0.0695639   0.06153589]
  [ 0.00111751 -0.15234253  0.10063156  0.05778597]
  [-0.0316438  -0.0516398   0.17971487  0.07490367]
  [-0.10333496  0.07261787  0.2896123   0.13996798]
  [ 0.130892    0.00722314  0.1544136  -0.08568334]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]
h is : [[ 0.130892    0.00722314  0.1544136  -0.08568334]]

结果可以看出,padding部分的output是0.
而最后的hidden_state和最后一个有效部分的output值相同.

因此可以暂时不用多做考虑.

参考

参考博客https://zhuanlan.zhihu.com/p/34393028
参考博客https://tangshusen.me/2019/03/09/tf-attention/ 高赞, 附源码解析.

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值