tf.gather,取指定维度多个索引的数据

tensorflow和numpy在数据处理上语法相似但又不完全一样,比如在numpy中想取指定维度的多个指定索引所指向的数据时,直接用一个列表保存索引就能直接取,比如:

# b的shape为[2, 3, 2]
b = np.array([[[1, 2], [2, 3], [3,4]], [[4, 5], [5,6], [6, 7]]])
a = [0, 1, 2]
# 假若要取b的第2个维度(从1算起)的以a为索引的数据,只需写成如下形式
b[:, a, :]

但是若b是tensor形式,则上述操作会报错!

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([01, 2])

而在tensor中则需要使用tf.gather方法来实现该功能。
tf.gather可以取指定维度任意索引的数据,举个例子:

b = np.array([[[1, 2], [2, 3], [3,4]], [[4, 5], [5,6], [6, 7]]])
b = tf.constant(b)
a = [0, 1, 2]
# 写法如下,效果和上面的numpy的效果相同
tf.gather(b, axis=1, indices=a)

这是tf.gather的用法,而想来大家也都听过tf.gather_nd,这个功能更加强大,它能对同一个维度的每一组数据分别取不同的索引值,比如取第1行的第二列+第二行的第三列。
关于gather_nd的用法见:tensorflow中tensor,从每行取指定索引元素

当然可以!`tf.gather` 是 TensorFlow 中的一个函数,它用于根据指定索引从一个张量中收集(或提)元素。这个函数在处理稀疏数据、从查找表中获或者重新排列张量的维度时非常有用。 `tf.gather` 的基本定义如下: ``` python tf.gather( params, indices, validate_indices=None, axis=None, batch_dims=0, name=None ) ``` 参数解释: - `params`:一个张量,从中收集元素。 - `indices`:一个整数张量,表示要收集的元素的索引。 - `validate_indices`(可选):布尔,用于检查索引是否越界。默认为 `None`,表示不进行检查。 - `axis`(可选):指定沿哪个轴收集元素。默认为 0。 - `batch_dims`(可选):整数,表示在收集操作之前,`params` 和 `indices` 张量的批处理维度。默认为 0。 - `name`(可选):操作的名称。 `tf.gather` 的返回是一个新的张量,其中包含从 `params` 中收集的元素。 下面是一个简单的例子: ``` python import tensorflow as tf # 创建一个张量 params = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 指定要收集的元素的索引 indices = tf.constant([0, 2]) # 使用 tf.gather 收集元素 result = tf.gather(params, indices, axis=1) # 输出结果 print(result) ``` 输出: ``` tf.Tensor( [[1 3] [4 6] [7 9]], shape=(3, 2), dtype=int32) ``` 在这个例子中,我们沿着第 1 轴(列方向)从 `params` 张量中收集了索引为 0 和 2 的元素。
最新发布
04-25
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值