从一个张量中收集元素
tf.gather_nd(
params,
indices,
name=None
)
举例:
import tensorflow as tf
a=[[[1,2,3],[4,5,6]],[[1,2,3],[-4,-5,-6]]]
x=tf.constant(a)
indices=[[0,0],[1,1]]
y=tf.gather_nd(a,indices)
with tf.Session() as sess:
print(sess.run(y))
这里的输出为:
[[1 2 3]
[-4 -5 -6]]