在学习 CS231n中的NetworkVisualization-PyTorch任务,讲解了使用torch.gather函数,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。
其中 gather有两种使用方式,一种为 torch.gather
另一种为 对象.gather。
首先介绍 对象.gather
torch.gather(input, dim, index)
gather(input, dim, index) 会沿 input 的第 dim 维度,根据 index 张量里的索引值「选取」input 中对应位置的元素。
给定 input 张量、维度 dim、 及一个与 input 在维度数量上相同、且在 dim 维的大小 ≤ input.size(dim) 的 index 张量 ,输出张量 out 与 index 的形状相同,其各位置的值就是 input 在相应索引处的值。

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
index = torch.tensor([[0, 2],
[1, 2]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
#
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
tensor([[ 3, 10],
[ 6, 10]])
import torch
torch.manual_seed(2) #为CPU设置种子用于生成随机数,以使得结果是确定的
def gather_example():
N, C = 4, 5
s = torch.randn(N, C)
y = torch.LongTensor([1, 2, 1, 3]) # 必须要为 LongTensor 不然会报错
print(s)
print(y)
print(s.gather(1,y.view(-1, 1)).squeeze())
gather_example()
'''
输出:
tensor([[-1.0408, 0.9166, -1.3042, -1.1097, 0.0299],
[-0.0498, 1.0651, 0.8860, -0.8110, 0.6737],
[-1.1233, -0.0919, 0.1405, 1.1191, 0.3152],
[ 1.7528, -0.7396, -1.2425, -0.1752, 0.6990]])
tensor([1, 2, 1, 3])
tensor([ 0.9166, 0.8860, -0.0919, -0.1752])
'''
对于上图的代码,首先通过 torch.randn 随机输出化出结果为
tensor([[-1.0408, 0.9166, -1.3042, -1.1097, 0.0299],
[-0.0498, 1.0651, 0.8860, -0.8110, 0.6737],
[-1.1233, -0.0919, 0.1405, 1.1191, 0.3152],
[ 1.7528, -0.7396, -1.2425, -0.1752, 0.6990]])
然后 我们根据索引 tensor([1, 2, 1, 3]) 对每一行进行索引,在第0行索引到位置=1的元素,即 0.9166,在第二行索引到位置=2的元素即 0.8860 以此类推,即为最后的结果。
另一种为 torch.gather
b = torch.Tensor([[1,2,3],[4,5,6]])
print (b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1)) # 按行来进行索引
print (torch.gather(b, dim=0, index=index_2)) # 按列来进行索引
'''
输出为:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
'''
当dim=0时候,那么就是按照行进行索引,输出第一行位置为0,1的元素,即1,2 。第二行位置为2,0的元素,即 6,4。
当dim=1时候,那么就是按照列进行索引,输出第一列位置为0,第二列位置为1,第三列位置为1的元素,即1,5,6。输出第二列位置为0,第二列位置为0,第三列位置为0的元素,即1,2,3。
综上,总结一下,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。你可以选择按照行和列的位置进行索引。
本文详细解析了PyTorch中的gather函数,包括其两种使用方式:对象.gather和torch.gather。通过实例展示了如何根据索引从张量中选取特定元素,适用于行和列的索引操作。
2953

被折叠的 条评论
为什么被折叠?



