The purpose of this function is to build an RNN; before calling it you must first define the cell type, e.g. an LSTM or GRU cell.
The dynamic_rnn signature in the API is:
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
Example (hidden_size, batch_size, and input_data are assumed to be defined beforehand):
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
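To make the shapes concrete, here is a minimal NumPy sketch (illustrative only, not TensorFlow code; all sizes are made-up toy values) of what dynamic_rnn computes for a BasicRNNCell: at each step h_t = tanh(x_t @ W_x + h_{t-1} @ W_h + b), and outputs stacks every h_t along the time axis.

```python
import numpy as np

# hypothetical toy sizes
batch_size, max_time, input_size, hidden_size = 2, 5, 3, 4

rng = np.random.default_rng(0)
x = rng.normal(size=(batch_size, max_time, input_size))   # plays the role of input_data
W_x = rng.normal(size=(input_size, hidden_size)) * 0.1
W_h = rng.normal(size=(hidden_size, hidden_size)) * 0.1
b = np.zeros(hidden_size)

h = np.zeros((batch_size, hidden_size))                   # like rnn_cell.zero_state
outputs = []
for t in range(max_time):                                 # unroll over max_time
    h = np.tanh(x[:, t] @ W_x + h @ W_h + b)              # one BasicRNNCell step
    outputs.append(h)
outputs = np.stack(outputs, axis=1)

print(outputs.shape)  # (2, 5, 4): [batch_size, max_time, hidden_size]
print(h.shape)        # (2, 4):    [batch_size, hidden_size]
```

Note that the final state equals the last slice of outputs, which is why for a BasicRNNCell the two shape comments use the same size.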
Multi-layer network:
This way of defining the stack is worth learning: building the layers with a list comprehension and wrapping them in MultiRNNCell avoids dimension errors (which mostly stem from TensorFlow version differences), and it also allows the embedding size and the hidden size to differ.
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)
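The shape comment above can be checked with a small NumPy sketch. For simplicity it uses plain tanh cells rather than real LSTM math, but the stacking logic is the same idea: each layer's output sequence is fed as the input sequence of the next layer, so the final outputs take the size of the last cell (256 here), while the embedding size (32 below, a made-up value) is free to differ.

```python
import numpy as np

def rnn_layer(x, hidden_size, rng):
    """Run a simple tanh RNN over x: [batch, time, in] -> [batch, time, hidden]."""
    batch, max_time, in_size = x.shape
    W_x = rng.normal(size=(in_size, hidden_size)) * 0.1
    W_h = rng.normal(size=(hidden_size, hidden_size)) * 0.1
    h = np.zeros((batch, hidden_size))
    outs = []
    for t in range(max_time):
        h = np.tanh(x[:, t] @ W_x + h @ W_h)
        outs.append(h)
    return np.stack(outs, axis=1)

rng = np.random.default_rng(0)
data = rng.normal(size=(4, 7, 32))     # [batch_size, max_time, embedding_size]
layer1 = rnn_layer(data, 128, rng)     # first cell: 128 units
layer2 = rnn_layer(layer1, 256, rng)   # second cell consumes layer1's outputs

print(layer2.shape)  # (4, 7, 256): final outputs take the last cell's size
```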
Parameters:
cell: an instance of an RNN cell, e.g. GRU or LSTM
inputs: the RNN input; with the default time_major=False its shape is [batch_size, max_time, ...] — the first two dimensions are required, and the remaining dimensions may hold (possibly nested) data
sequence_length: optional, an int32 or int64 vector of shape [batch_size] giving the true length of each sequence
initial_state: optional, the initial state of the RNN
The remaining parameters are omitted here; they are rarely needed.
Reference: https://tensorflow.google.cn/api_docs/python/tf/nn/dynamic_rnn?hl=zh-cn
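sequence_length deserves a closer look: when it is supplied, dynamic_rnn zeroes the outputs past each example's true length and stops updating that example's state. A NumPy sketch of this masking behavior (illustrative only; a plain tanh cell with made-up sizes, not the real implementation):

```python
import numpy as np

batch_size, max_time, hidden_size = 2, 4, 3
sequence_length = np.array([2, 4])      # true length of each example

rng = np.random.default_rng(1)
x = rng.normal(size=(batch_size, max_time, hidden_size))
W = rng.normal(size=(hidden_size, hidden_size)) * 0.1

h = np.zeros((batch_size, hidden_size))
outputs = np.zeros((batch_size, max_time, hidden_size))
for t in range(max_time):
    valid = (t < sequence_length)[:, None]        # which rows are still running
    h_new = np.tanh(x[:, t] + h @ W)
    h = np.where(valid, h_new, h)                 # freeze state past the length
    outputs[:, t] = np.where(valid, h_new, 0.0)   # zero outputs past the length

print(outputs[0, 2:])  # all zeros: example 0 ended after step t=1
```

So for padded batches the returned state is the state at each example's last valid step, not at max_time.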
Return values:
outputs: shape [batch_size, max_time, cell.output_size] (with time_major=False), where output_size equals the hidden size
state: the final state, of shape [batch_size, cell.state_size]; for an LSTM cell this is an LSTMStateTuple, and for MultiRNNCell it is a tuple with one entry per layer