pytorch之张量切片函数index_select介绍

最近加入了一个deeplearning的学习小组开始学习pytorch,初始对这个向量切片函数index_select()感到有些疑惑,经过自己一番实验之后,应该算是懂了吧,和大家一起分享一下实验结果。

index_select有两种用法,一种是将某一个张量(tensor)作为变量传入torch.index_select()函数,还有一个是tensor的内置方法index_select。用法分别是

a=tensor.tensor([1,2])
# first use
c = torch.index_select(a, 1, torch.tensor([0]))
# second use
d = a.index_select(1 , torch.tensor([0]))

这两种用法的区别和python中的sort以及sorted函数有些类似。第一种用法 c会成为一个新的张量且不会和a共用内存,而第二种用法d会和a共用内存 

然后,如何使用这个切片函数呢?首先,我们看一下官方文档的介绍。

torch.index_select(input, dim, index, out=None) → Tensor

沿着指定维度对输入进行切片,取index中指定的相应项(index为一个LongTensor),然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。

注意: 返回的张量不与原始张量共享内存空间。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 包含索引下标的一维张量
  • out (Tensor, optional) – 目标张量

其实只要搞清楚这里的 dim和index参数的含义就好。通常我们能够写出来的,直观的数组就是二维数组,可以把它想象成一个表格,行是第零维(注意索引是从0开始的

Pytorch中,可以通过多种方法对Tensor进行切片及索引,从而获取tensor的一部分数据得到新的tensor,以下是一些常见的方法: ### 索引表达式 使用索引表达式可以直接从原有的tensor中得到相应的子tensor。例如,使用基本索引可以获取特定位置的元素,如 `a[2, 2]` 表示获取二维张量 `a` 中第2行第2列的元素 [^1][^3]。 ### 切片操作 切片操作可以获取tensor的连续部分。例如: - `a[1:, :-1]` 表示获取二维张量 `a` 中从第1行开始到最后一行,从第0列开始到倒数第二列的元素。 - `a[::2]` 表示在第0维上从头到尾每隔一个元素进行采样,即获取第0、2、4...行的元素。需要注意的是,PyTorch现在不支持负步长 [^3]。 ### 带步长的切片 可以通过 `start:stop:step` 的形式进行带步长的切片。例如 `0:28:2` 表示从索引0到27,每隔一个元素进行采样;`::2` 表示在相应维度上从头到尾每隔一个元素进行采样 [^5]。 ### 整数索引 通过指定整数列表来选择特定的行和列。例如: ```python import torch a = torch.arange(9).view(3, 3) rows = [0, 1] cols = [2, 2] print(a[rows, cols]) ``` 这段代码表示选择第0行第2列和第1行第2列的元素 [^3]。 ### 布尔索引 使用布尔掩码来选择满足特定条件的元素。例如: ```python import torch a = torch.arange(9).view(3, 3) index = a > 4 print(a[index]) ``` 这段代码会选择张量 `a` 中大于4的元素 [^3]。 ### `torch.nonzero` 索引 使用 `torch.nonzero` 函数可以获取满足特定条件的元素的索引。例如: ```python import torch a = torch.arange(9).view(3, 3) index = torch.nonzero(a >= 8) print(index) ``` 这段代码会获取张量 `a` 中大于等于8的元素的索引 [^3]。 ### `torch.where` 索引 `torch.where` 函数可以根据条件选择元素。例如: ```python import torch x = torch.randn(3, 2) y = torch.ones(3, 2) print(torch.where(x > 0, x, y)) ``` 这段代码会根据条件 `x > 0` 选择 `x` 中的元素,不满足条件的位置选择 `y` 中的元素 [^3]。 ### `masked_select` 索引 使用 `masked_select` 函数可以根据布尔掩码选择元素。例如: ```python import torch a = torch.rand(4, 4) mask = a.ge(0.5) b = a.masked_select(mask) print(b) ``` 这段代码会选择张量 `a` 中大于等于0.5的元素 [^4]。 ### `...` 表示任意多的维度 `...` 表示任意多的维度,可以根据原数据的形状和给定形式推断。例如 `a[0, ...]` 等价于 `a[0, :, :, :]` 或 `a[0]`;`a[:, 1, ...]` 等价于 `a[:, 1, :, :]` 或 `a[:, 1]` [^5]。 ### `tensor.index_select` 接口 可以使用 `tensor.index_select` 接口从原有的tensor中选择特定维度上的元素。例如: ```python import torch a = torch.arange(9).view(3, 3) indices = torch.tensor([0, 2]) b = a.index_select(0, indices) print(b) ``` 这段代码会选择张量 `a` 中第0维上索引为0和2的元素 [^1]。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值