一、学习单步的RNN:RNNCell
如果要学习TensorFlow中的RNN,第一站应该就是去了解“RNNCell”,它是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。
借助图片来说可能更容易理解。假设我们有一个初始状态h0,还有输入x1,调用call(x1, h0)后就可以得到(output1, h1):
再调用一次call(x2, h1)就可以得到(output2, h2):
也就是说,每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。
在代码实现上,RNNCell只是一个抽象类,我们用的时候都是用的它的两个子类BasicRNNCell和BasicLSTMCell。顾名思义,前者是RNN的基础类,后者是LSTM的基础类。这里推荐大家阅读其源码实现(地址:http://t.cn/RNJrfMl),一开始并不需要全部看一遍,只需要看下RNNCell、BasicRNNCell、BasicLSTMCell这三个类的注释部分,应该就可以理解它们的功能了。
除了call方法外,对于RNNCell,还有两个类属性比较重要:
-
state_size
-
output_size
前者是隐层的大小,后者是输出的大小。比如我们通常是将一个batch送入模型计算,设输入数据的形状为(batch_size, input_size),那么计算时得到的隐层状态