torch.gather函数的简单理解与使用

功能:根据索引来对高维tensor进行选择
要求:

  • input tensor 与 index 的 dim一致
  • index.shape < input.shape
torch.gather(input, dim, index) → Tensor

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
import torch

# [[1,2,3],
#  [4,5,6],
#  [7,8,9]]
input = torch.range(1, 9).view(3, 3)

示例1

#====================== dim=0 索引======================#
# 表示提供采样索引维度的是dim=0
dim = 0		
# index的维度为 [1, 3], 说明输出维度也为 [1, 3], index每个元素的索引为 [[(0, 0), (0, 1), (0, 2)]]
index = torch.tensor([[2, 1, 0]])
# 用index 来替换index的索引列表中的dim=0 得到: [[(2, 0), (1, 1), (0, 2)]]
output = torch.gather(input, dim, index)
# 将input索引为 (2, 0), (1, 1), (0, 3) 取出来就是 [[7, 5, 3]] 	

在这里插入图片描述


示例2

#======================== dim=1 索引=====================#
# 表示提供的采样索引维度dim=1
dim = 1
# index的维度为 (1, 3), 也就是index每个元素的索引为    [[(0, 0), (0, 1), (0, 2)]], 
index = torch.tensor([[2, 1, 0]])
# 用index的取值来替代 index的索引列表中的dim=1的元素得:[[(0,2) (0,1) (0,0)]]
# 将input索引为 (0,2) (0,1) (0,0) 取出来就是[[3, 2, 1]]
optput = torch.gather(input, dim, index)

在这里插入图片描述


示例3

#=============================================#
dim = 1	# 表示采样索引为1
index = torch.tensor([[2],
		 			  [1],
		 			  [0]])
# index的索引为[(0, 0),
			   (1, 0),
			   (2, 0)]
# 使用index的其余维索引来补全后得到:
'''
 [(0, 2),
  (1, 1),
  (2, 0)]
'''
# 对input索引
output = torch.gather(input, dim, index)
'''
[[3],
 [5],
 [7]]
'''

在这里插入图片描述


示例4

#====================== 多维index =====================#
dim = 1
index = torch.tensor([[0, 2], 
                      [1, 2]])
# index 的索引为 [[(0,0), (0,1)], 
# 				 [(1,0), (1,1)]]
# 用除了1维以外的索引将index补全得:
# [[(0,0), (0,2)], 
#  [(1,1), (1,2)]]

# 对input索引
output = torch.gather(input, dim, index)		  
#[[0, 3]
# [5, 6]]

在这里插入图片描述

更高维度的gather索引也是如此,先生成index每个元素的索引,再用index的值来替代dim维度的索引值,最后按照索引值到input中索引得到output

### 回答1: torch.gather函数是PyTorch中的一个函数,用于在给定维度上按索引从输入张量中提取元素并构建新的张量。 torch.gather函数的语法为:torch.gather(input, dim, index, out=None)。 参数说明: - input:输入张量,即需要从中提取元素的张量。 - dim:要在哪个维度上进行提取操作。 - index:一个包含需要提取元素的索引的张量。 - out:一个可选的输出张量。 在torch.gather函数中,我们会按照dim指定的维度,在input张量上进行提取操作。提取操作是根据index张量中给定的索引值来进行的。最终会构建一个新的张量,其中包含了根据索引从input张量中提取出来的元素。 例如,如果input是一个2维张量,shape为(3,4),而index是一个1维张量,shape为(3,),则dim的取值范围为[0, 1]。如果dim=0,那么提取操作将沿着第一个维度进行,在每一列上按照index张量中对应的值进行元素的提取。如果dim=1,那么提取操作将沿着第二个维度进行,在每一行上按照index张量中对应的值进行元素的提取。 使用torch.gather函数可以灵活地根据给定的索引从输入张量中提取出所需的元素,这对于实现一些特定需求的操作非常有用。例如,可以在处理图像分类任务时,根据预测的类别标签,从softmax输出概率中提取出对应类别的概率,进而用于计算损失函数或者评估模型性能等。 ### 回答2: torch.gather函数是一个PyTorch中的操作函数,用于在指定维度上根据索引获取原始张量中的元素。这个函数使用方式为: output = torch.gather(input, dim, index, out=None, sparse_grad=False) 其中,input是原始的张量,dim是指定的维度,index是需要提取的元素的索引。函数会根据dim指定的维度,在input张量中提取index中指定的元素,并返回一个新的张量output。 例如,假设input是一个3x4的二维张量,index是一个2x3的二维张量,dim的取值为1,那么torch.gather函数会在input的第1个维度上根据index中的元素索引,提取相应的元素。最终得到的output是一个2x3的张量。 torch.gather函数在很多机器学习任务中非常有用。例如,在序列标注任务中,我们可以使用torch.gather函数根据标签索引来选择对应的预测结果。在图像分类任务中,我们可以根据类别索引使用torch.gather函数进行结果的选择。此外,在自然语言处理任务中,torch.gather函数也可以用来根据单词的索引来选择对应的词向量。 需要注意的是,所提取的元素的维度必须index的维度一致,否则会引发异常。此外,dim的取值必须在0到input的维度之间,否则也会引发异常。如果不指定out参数,函数会返回一个新的张量作为输出,如果指定了out参数,则会把提取的结果保存到指定的张量中。最后,如果sparse_grad为True,则会返回一个稀疏梯度,否则返回一个密集梯度。 总之,torch.gather函数提供了一种方便和高效地根据索引提取元素的方式,广泛应用于各种机器学习任务中。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV科研随想录

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值