前言
本系列主要主要是记录下Tensorflow在RNN实现这一块的相关代码,不做详细解释,主要是翻译加笔记。
RNNCell
在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的,所以就首先讲一下这个,下面是基本的代码:
class RNNCell(object):
def __call__(self, inputs, state, scope=None):
raise NotImplementedError("Abstract method")
@property
def state_size(self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if nest.is_sequence(state_size):
state_size_flat = nest.flatten(state_size)
zeros_flat = [
array_ops.zeros(
array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
dtype=dtype)

这篇博客主要解析Tensorflow中RNNCell的实现,包括其作为RNN核心的结构和功能,特别是BasicRNNCell的源代码。RNNCell是一个包含状态并处理输入矩阵的对象,输出输出矩阵和可能的状态矩阵。文章详细介绍了RNNCell的__call__方法及其生成初始化状态的方法,并探讨了BasicRNNCell的简单实现。
最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



