TensorFlow seq2seq decoder error: 'Tensor' object is not iterable.

Building a decoder with TensorFlow's `tf.contrib.seq2seq` raised an error:
```python
# Build RNN cell
decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=self.args.rnn_dim)
# Helper
# gen_outputs: the decoder input sequence, starting with SOS
decoder_inputs_embedded = tf.nn.embedding_lookup(self.embeddings, self.gen_outputs)
projection_layer = tf.layers.Dense(self.vocab_size, kernel_initializer=self.initializer)
# Not yet sure whether self.gen_len needs +1 here
helper = tf.contrib.seq2seq.TrainingHelper(
    decoder_inputs_embedded, self.gen_len, time_major=False)
# Decoder
# state: a single tensor of shape (batch_size, dim); an LSTMCell cannot use it as-is
decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, initial_state=state,
    output_layer=projection_layer)
# Dynamic decoding
decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, impute_finished=True, maximum_iterations=max_review_length)

```

The error output:

```
Traceback (most recent call last):
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 771, in <module>
    exp = CFExperiment(inject_params=None)
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 105, in __init__
    num_user=self.num_users)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 89, in __init__
    self.build_graph()
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 1156, in build_graph
    self.gen_loss, self.gen_acc, self.key_word_loss = self._gen_review(q1_output, q2_output, r_input)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 720, in _gen_review
    maximum_iterations=max_review_length)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 138, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 591, in call
    (c_prev, m_prev) = state
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
```

Cause:

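The `state` passed as `initial_state` is a single tensor of shape (batch_size, dim).

However, `decoder_cell` is an `LSTMCell`, and an LSTM's state holds two tensors: the cell state C and the hidden state H, each of shape (batch_size, dim). That is why the initial state is conceptually (2, batch_size, dim). With the default `state_is_tuple=True`, `LSTMCell` expects this as a pair of tensors (an `LSTMStateTuple(c, h)`), not as one tensor: its `call` method unpacks the state with `(c_prev, m_prev) = state`, as the last frames of the traceback show. A graph-mode `Tensor` cannot be iterated, so that unpacking raises `TypeError: 'Tensor' object is not iterable.`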
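
To see where the unpacking breaks, here is a minimal NumPy sketch of one LSTM step. It is not TensorFlow's implementation: `lstm_step`, the fused weight layout, and the shapes are assumptions chosen for illustration. Like `LSTMCell.call`, it unpacks its state into two arrays, so it works with a (c, h) pair and fails with a single (batch, units) array. NumPy arrays are iterable, so the failure appears here as a `ValueError` rather than the `TypeError` a graph-mode TensorFlow tensor raises.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, state, W, b):
    """One LSTM step (illustrative NumPy sketch, not TensorFlow's code).

    `state` must be a pair (c, h), each of shape (batch, units).
    W has shape (input_dim + units, 4 * units); b has shape (4 * units,).
    """
    c_prev, h_prev = state  # mirrors `(c_prev, m_prev) = state` in LSTMCell.call
    z = np.concatenate([x, h_prev], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)  # input gate, candidate, forget gate, output gate
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)
    h = sigmoid(o) * np.tanh(c)
    return h, (c, h)  # output, and the new state (still a pair)


rng = np.random.default_rng(0)
batch, input_dim, units = 3, 5, 4
x = rng.standard_normal((batch, input_dim))
W = rng.standard_normal((input_dim + units, 4 * units))
b = np.zeros(4 * units)

# Works: a (c, h) pair, analogous to LSTMStateTuple(c, h) in TensorFlow 1.x.
pair_state = (np.zeros((batch, units)), np.zeros((batch, units)))
out, (c1, h1) = lstm_step(x, pair_state, W, b)
print(out.shape, c1.shape, h1.shape)  # (3, 4) (3, 4) (3, 4)

# Fails: one (batch, units) array, like the single `state` tensor in the post.
single_state = np.zeros((batch, units))
try:
    lstm_step(x, single_state, W, b)
except ValueError as err:
    print("ValueError:", err)  # unpacking 3 rows into 2 names fails
```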

Solution:

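Change `decoder_cell` to a `GRUCell`, e.g. `tf.nn.rnn_cell.GRUCell(num_units=self.args.rnn_dim)`. A GRU's state is a single tensor of shape (batch_size, dim), so the existing `state` can be passed as `initial_state` unchanged, provided dim equals `rnn_dim`.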
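
A GRU avoids the mismatch because its state is one tensor. The NumPy sketch below is not TensorFlow's code; `gru_step`, the weight shapes, and `encoder_state` are hypothetical names for illustration. It shows a single (batch, units) state going straight in and a state of the same shape coming out. If an LSTM decoder is required instead, TensorFlow 1.x also lets you keep `LSTMCell` and pass `initial_state=tf.nn.rnn_cell.LSTMStateTuple(c=..., h=...)`, for example an LSTM encoder's final state; that alternative is not exercised here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h_prev, W_gates, b_gates, W_cand, b_cand):
    """One GRU step (illustrative NumPy sketch, not TensorFlow's code).

    The state is a single array h_prev of shape (batch, units), so an
    encoder state of shape (batch, dim) with dim == units fits directly.
    """
    xh = np.concatenate([x, h_prev], axis=1)
    r, u = np.split(sigmoid(xh @ W_gates + b_gates), 2, axis=1)  # reset, update gates
    cand = np.tanh(np.concatenate([x, r * h_prev], axis=1) @ W_cand + b_cand)
    h = u * h_prev + (1.0 - u) * cand  # blend old state and candidate
    return h, h  # output and new state are the same array


rng = np.random.default_rng(0)
batch, input_dim, units = 3, 5, 4
x = rng.standard_normal((batch, input_dim))
encoder_state = rng.standard_normal((batch, units))  # hypothetical single-tensor state
W_gates = rng.standard_normal((input_dim + units, 2 * units))
b_gates = np.ones(2 * units)
W_cand = rng.standard_normal((input_dim + units, units))
b_cand = np.zeros(units)

out, new_state = gru_step(x, encoder_state, W_gates, b_gates, W_cand, b_cand)
print(out.shape, new_state.shape)  # (3, 4) (3, 4)
```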

 
