TensorFlow LSTM Shape Transformations

This article walks through building an LSTM network with TensorFlow and printing the shape of each layer to verify that the graph is constructed correctly. Using the build_rnn_graph_lstm code from the PTB model as an example, it shows how to define the network configuration, create LSTM cells, build the multi-layer RNN, and initialize its state.


The shapes in a TensorFlow graph can be printed out for debugging. Take the build_rnn_graph_lstm code from the PTB model as an example:


 
import numpy as np
import tensorflow as tf


class SmallConfig(object):
  """Small config."""
  init_scale = 0.1
  learning_rate = 1.0
  max_grad_norm = 5
  num_layers = 2
  num_steps = 2
  hidden_size = 200
  max_epoch = 4
  max_max_epoch = 13
  keep_prob = 1.0
  lr_decay = 0.5
  batch_size = 3
  vocab_size = 10000

def get_lstm_cell(config, is_training):
    # One LSTM layer; reuse variables when not training (e.g. for an eval graph).
    return tf.contrib.rnn.BasicLSTMCell(
        config.hidden_size, forget_bias=0.0, state_is_tuple=True,
        reuse=not is_training)

def build_rnn_graph_lstm(inputs, config, is_training):
    """Build the inference graph using canonical LSTM cells."""

    # Slightly better results can be obtained with forget gate biases
    # initialized to 1 but the hyperparameters of the model would need to be
    # different than reported in the paper.
    def make_cell():
        cell = get_lstm_cell(config, is_training)
        if is_training and config.keep_prob < 1:
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, output_keep_prob=config.keep_prob)
        return cell

    cell = tf.contrib.rnn.MultiRNNCell(
        [make_cell() for _ in range(config.num_layers)], state_is_tuple=True)

    initial_state = cell.zero_state(config.batch_size, tf.float32)
    state = initial_state
    # Simplified version of tf.nn.static_rnn().
    # This builds an unrolled LSTM for tutorial purposes only.
    # In general, use tf.nn.static_rnn() or tf.nn.static_state_saving_rnn().
    #
    # The alternative version of the code below is:
    #
    # inputs = tf.unstack(inputs, num=config.num_steps, axis=1)
    # outputs, state = tf.nn.static_rnn(cell, inputs,
    #                                   initial_state=initial_state)
    outputs = []
    with tf.variable_scope("RNN"):
        for time_step in range(config.num_steps):
            if time_step > 0: tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(inputs[:, time_step, :], state)
            print("cell_output = ", cell_output)
            print("cell_state = ", state)
            outputs.append(cell_output)
    outputs = tf.reshape(tf.concat(outputs, 1), [-1, config.hidden_size])
    return outputs, state

config = SmallConfig()
#inputs = tf.get_variable("inputs", initializer=[[[1.],[2.]], [[3.],[4.]], [[5.],[6.]]])
inputs = tf.placeholder(dtype=tf.float32, shape=[3, 2, 1])  # [batch_size, num_steps, input_dim]
print(inputs)
output, state = build_rnn_graph_lstm(inputs=inputs, config=config, is_training=True)
print("output =", output)
print("state =", state)

The corresponding output is:

Tensor("Placeholder:0", shape=(3, 2, 1), dtype=float32)
cell_output =  Tensor("RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/mul_2:0", shape=(3, 200), dtype=float32)
cell_state =  (LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/add_1:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/mul_2:0' shape=(3, 200) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/add_1:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/mul_2:0' shape=(3, 200) dtype=float32>))
cell_output =  Tensor("RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/mul_5:0", shape=(3, 200), dtype=float32)
cell_state =  (LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/add_3:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/mul_5:0' shape=(3, 200) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/add_3:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/mul_5:0' shape=(3, 200) dtype=float32>))
output = Tensor("Reshape:0", shape=(6, 200), dtype=float32)
state = (LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/add_3:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_0/cell_0/basic_lstm_cell/mul_5:0' shape=(3, 200) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/add_3:0' shape=(3, 200) dtype=float32>, h=<tf.Tensor 'RNN/RNN/multi_rnn_cell/cell_1/cell_1/basic_lstm_cell/mul_5:0' shape=(3, 200) dtype=float32>))
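
Note how the final output shape (6, 200) arises: each LSTMStateTuple holds (c, h) tensors of shape (batch_size, hidden_size) = (3, 200), one tuple per layer, and each per-step cell_output is also (3, 200). tf.concat(outputs, 1) joins the num_steps = 2 step outputs into a (3, 400) tensor, and reshaping with [-1, config.hidden_size] splits that back into (6, 200). A minimal numpy sketch of just this concat + reshape step, with hidden_size shrunk to 2 so the result is easy to eyeball:

import numpy as np

t0 = np.zeros((3, 2))  # output at time step 0: [batch_size=3, hidden_size=2]
t1 = np.ones((3, 2))   # output at time step 1

merged = np.concatenate([t0, t1], axis=1)  # like tf.concat(outputs, 1): (3, 4)
flat = merged.reshape(-1, 2)               # like tf.reshape(..., [-1, hidden]): (6, 2)
print(merged.shape, flat.shape)            # (3, 4) (6, 2)
# Rows of flat interleave time steps per batch element:
# [b0_t0, b0_t1, b1_t0, b1_t1, b2_t0, b2_t1]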

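These prints all happen at graph-construction time, so they only show static shapes. To check actual values you have to run the graph in a session; a minimal sketch, reusing the inputs placeholder and the output/state tensors defined above:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = np.arange(6, dtype=np.float32).reshape(3, 2, 1)  # any [3, 2, 1] data works
    output_val, state_val = sess.run([output, state], feed_dict={inputs: feed})
    print(output_val.shape)  # (6, 200)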

But if you use the alternative version from the code comments (tf.unstack plus tf.nn.static_rnn) instead:
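
Here is a minimal sketch of that variant, assuming it replaces the manual 'with tf.variable_scope("RNN"):' loop inside build_rnn_graph_lstm, so that cell and initial_state are already defined:

inputs_list = tf.unstack(inputs, num=config.num_steps, axis=1)  # num_steps tensors of shape (3, 1)
outputs, state = tf.nn.static_rnn(cell, inputs_list,
                                  initial_state=initial_state)
print("outputs =", outputs)  # a list of num_steps tensors, each (3, 200)
print("state =", state)
outputs = tf.reshape(tf.concat(outputs, 1), [-1, config.hidden_size])

The shapes come out the same as in the manual loop; only the tensor names in the printout differ, because tf.nn.static_rnn builds its ops under its own variable scope rather than the hand-written "RNN" scope.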
