Tensorflow之根据索引号收集数据
tf.gather
tf.gather 可以实现根据索引号收集数据的目的。考虑班级成绩册的例子,共有4个班,每个班35个学生,8门学科,保存成绩册的张量shape为[4,35,8]。
x = tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32)
现在需要收集1-2班的成绩册,可以给定班级索引号:[0,1]/[:2],班级的维度为axis=0:
In [38]:tf.gather(x,[0,1],axis=0) # 在班级维度收集第1-2 号班级成绩册
Out[38]:<tf.Tensor: id=83, shape=(2, 35, 8), dtype=int32, numpy=
array([[[43, 10, 93, 85, 75, 87, 28, 19],
[52, 17, 44, 88, 82, 54, 16, 65],
[98, 26, 1, 47, 59, 3, 59, 70],…
实际上,对于上述需求,通过切片𝑥[: 2]可以更加方便地实现。但是对于不规则的索引方式,比如,需要抽查所有班级的第1,4,9,12,13,27 号同学的成绩,则切片方式实现起来非常麻烦,而tf.gather 则是针对于此需求设计的,使用起来非常方便:
In [39]: # 收集第1,4,9,12,13,27 号同学成绩
tf.gather(x,[0,3,8,11,12,26],axis=1)
Out[39]:<tf.Tensor: id=87, shape=(4, 6, 8), dtype=int32, numpy=
array([[[43

本文深入探讨了TensorFlow中tf.gather与tf.gather_nd函数的使用技巧,包括如何根据索引号收集多维数据,适用于不规则索引及多维度坐标数据收集场景,极大提升了数据处理效率。
最低0.47元/天 解锁文章
812

被折叠的 条评论
为什么被折叠?



