Tensorflow - dynamic_rnn 学习

本文详细介绍了如何使用TensorFlow构建RNN网络,包括LSTM与GRU单元的定义,以及如何通过dynamic_rnn函数搭建多层网络。特别强调了避免维度错误的技巧,适合深度学习初学者及实践者参考。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

函数目的就是为了构建一个RNN网络,前面少不了定义cell的类型,如LSTM与GRU等

API 里面 dynamic_rnn 函数的参数如下:

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

例子:

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

 多层网络:

这个定义的方法值得学习一个,因为这样定义可以避免维度错误,主要原因是TensorFlow的版本问题,可以构建一个多层网络,而且embedding size 和 hidden size的维度可以不同。

# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)

参数 :

cell : RNN单元的一个实例如GRU LSTM等

inputs:[batch_size,max_time] 这两个参数比如有,后面可以是嵌套的数据

sequence_length :可选,int32 或 int64

initial_state:可选,初始状态

其余的省略了 一般也不用

参考https://tensorflow.google.cn/api_docs/python/tf/nn/dynamic_rnn?hl=zh-cn

返回值

outouts [batch_size, max_time, cell.output_size] 这样的形状,output_size = hidden size

state 末态,形状为[batch_size, cell.state_size]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值