这三个函数在pytorch中关于矩阵操作的非常实用的函数。我认为要想熟练的使用pytorch,能够灵活的使用这三个函数是至关重要的
三者的相同点:维度->数据的映射方式
因为三者都存在相似的地方,所以我这里放在一起来讲。这个共同点就是index -> value的方式:这里以官方给的gather函数对应为例:
# for a 3-D data
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
这样一看,并不好理解,举个例子:
- 关于shape的变化:
- 输入为[3, 4, 1]的数据x | 它的index为[3, 2, 1]
- x.gather(dim=1, index)输出维度为[3, 2, 1]。它保持另外两维不变,仅在这一维上操作。
- 关于数据的变化
- idx中的数据代表在指定维度上的index。
- idx中的数据代表在指定维度上的index。
topk
其实前面讲的映射方式计算起来还是容易乱,不过幸好并不影响我们的使用。emm实在不能理解可以忽略,只需要知道在指定维度上操作即可
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
- 主要用途:依照大小,从矩阵某维度取值和取索引。常与scatter、gather连用。
- 函数返回两个变量:value和index
- 维度变化:假设指定维度为1,则(b, n, m)-> (b, k, m)
- 其它用途:topk的数据默认按照从大到小排列,因此我们可以当做矩阵中的数据排序来用,若largest=False则为升序:
gather
torch.gather(input, dim, index, out=None) → Tensor
- 用途:依照index来对矩阵进行取值
- 函数返回与输入idx维度相同的tensor
- 维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n)。则输出为(b,k,n)。在维度1上按照index进行取值。
scatter
torch.scatter(input, dim, index, src) → Tensor
-
用途:与gather类似,不过它并不用来取值。scatter用来更替矩阵中指定index位置的值。
-
维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n),src=(b,m,n)。则输出为(b,k,n)。在维度1上按照index从src取值,然后替换到input上相同的index位置。
-
两种用法:
- 一般要求source的维度为input维度相同,如下例:
- 当然,也可以直接指定要替换的值,如下: