更新显示格式
本人 tensorflow新手, 由于训练和测试过程中的batch_size不一致, 需要复用graph中的一些功能,遇到BasicRNNCell复用问题,经过不断调查, 反复测试,问题解决。
分享给大家 ,如有疑问欢迎提问指正
出现问题的核心code如下 , 不再贴出详细的代码
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128, input_size=None, activation=tf.nn.tanh)
outputs1, _ = tf.nn.rnn(rnn_cell, input1, dtype=tf.float32) #input1 shape=[4,125,1000] for training
outputs2, _ = tf.nn.rnn(rnn_cell, input2, dtype=tf.float32) #input1 shape=[4,1,1000] for testing
以上code错误信息如下:
ValueError: Variable RNN/BasicRNNCell/Linear/Matrix already exists, disallowed. Did you mean to set reuse=True in VarScope?
调查得知错误是因为在BasicRNNCell 使用到了get_variable操作,第一次的时候创建新的variable,第二次调用的时候检测到命名冲突,报错。
程序目的是为了使training和testing使用同样的rnn_cell,修改后的code如下:
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128, input_size=None, activation=tf.nn.tanh)
with tf.variable_scope("rcnn", reuse=None):
outputs1, _ = tf.nn.rnn(rnn_cell, input1, dtype=tf.float32) #input1 shape=[4,125,1000] for training
with tf.variable_scope("rcnn", reuse=True):
outputs2, _ = tf.nn.rnn(rnn_cell, input2, dtype=tf.float32) #input1 shape=[4,1,1000] for testing
据我所知在使用如下Cell时同样会出现同样的错误
tf.contrib.rnn.MultiRNNCell
tf.nn.rnn_cell.BasicRNNCell
tf.contrib.rnn.BasicLSTMCell
Done