---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_553597/306262114.py in <module>
7 ## Training
8 # Epoch_list,Loss_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch)
----> 9 Epoch_list,Loss_list, Min_Distance_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch, min_distance_threshold=0.7,flags_schedule=[(1, 0), (0, 1), (1, 1)],iter_per_stage=50)
/tmp/ipykernel_553597/4102420687.py in model_train(batchsize, channel_SNR, noise_init, nl_factor, eq_flag, epsilon, earlystop_epoch, min_distance_threshold, flags_schedule, iter_per_stage)
58
59 (batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S,
---> 60 p_s, norm_constellation, x, min_distance) = train_step(
61 channel_SNR, noise_tf, GS_flag_now, PS_flag_now, eq_flag, epsilon, min_distance_threshold
62 )
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
/tmp/__autograph_generated_file_jsnzuik.py in tf__train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon, min_distance_threshold)
39 batch_size = ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(p_s),), None, fscope)[0]
40 batch_indices = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).range, (ag__.ld(batch_size),), None, fscope)[:, ag__.ld(tf).newaxis, ag__.ld(tf).newaxis], [1, ag__.ld(M_int), ag__.ld(k)]), None, fscope)
---> 41 gather_indices = ag__.converted_call(ag__.ld(tf).stack, ([ag__.ld(batch_indices), ag__.converted_call(ag__.ld(tf).tile, (ag__.ld(topk_indices)[:, :, ag__.ld(tf).newaxis, :], [1, 1, ag__.ld(k), 1]), None, fscope)],), dict(axis=(- 1)), fscope)
42 neighbor_probs = ag__.converted_call(ag__.ld(tf).gather_nd, (ag__.ld(p_s), ag__.ld(gather_indices)), None, fscope)
43 neighbor_sum = ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(neighbor_probs),), dict(axis=(- 1)), fscope)
ValueError: in user code:
File "/tmp/ipykernel_553597/675414708.py", line 77, in train_step *
gather_indices = tf.stack([
ValueError: Shapes must be equal rank, but are 3 and 4
From merging shape 0 with other shapes. for '{{node stack_1}} = Pack[N=2, T=DT_INT32, axis=-1](Tile, Tile_1)' with input shapes: [1,8,3], [1,8,3,3].
Latest release