Pytorch中torch.gather函数

本文详细解析了PyTorch中的gather函数,包括其两种使用方式:对象.gather和torch.gather。通过实例展示了如何根据索引从张量中选取特定元素,适用于行和列的索引操作。

在学习 CS231n中的NetworkVisualization-PyTorch任务,讲解了使用torch.gather函数,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。
其中 gather有两种使用方式,一种为 torch.gather
另一种为 对象.gather。

首先介绍 对象.gather

torch.gather(input, dim, index)
gather(input, dim, index) 会沿 input 的第 dim 维度,根据 index 张量里的索引值「选取」input 中对应位置的元素。
给定 input 张量、维度 dim、 及一个与 input 在维度数量上相同、且在 dim 维的大小 ≤ input.size(dim) 的 index 张量 ,输出张量 out 与 index 的形状相同,其各位置的值就是 input 在相应索引处的值。
在这里插入图片描述

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
index = torch.tensor([[0, 2],
                      [1, 2]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

#
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
tensor([[ 3, 10],
        [ 6, 10]])
import torch

torch.manual_seed(2)   #为CPU设置种子用于生成随机数,以使得结果是确定的 
def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])  # 必须要为 LongTensor 不然会报错
    print(s)
    print(y)
    print(s.gather(1,y.view(-1, 1)).squeeze())
    

gather_example()


'''
输出:

tensor([[-1.0408,  0.9166, -1.3042, -1.1097,  0.0299],
        [-0.0498,  1.0651,  0.8860, -0.8110,  0.6737],
        [-1.1233, -0.0919,  0.1405,  1.1191,  0.3152],
        [ 1.7528, -0.7396, -1.2425, -0.1752,  0.6990]])
tensor([1, 2, 1, 3])
tensor([ 0.9166,  0.8860, -0.0919, -0.1752])

'''

对于上图的代码,首先通过 torch.randn 随机输出化出结果为

tensor([[-1.0408,  0.9166, -1.3042, -1.1097,  0.0299],
        [-0.0498,  1.0651,  0.8860, -0.8110,  0.6737],
        [-1.1233, -0.0919,  0.1405,  1.1191,  0.3152],
        [ 1.7528, -0.7396, -1.2425, -0.1752,  0.6990]])

然后 我们根据索引 tensor([1, 2, 1, 3]) 对每一行进行索引,在第0行索引到位置=1的元素,即 0.9166,在第二行索引到位置=2的元素即 0.8860 以此类推,即为最后的结果。

另一种为 torch.gather

b = torch.Tensor([[1,2,3],[4,5,6]])
print (b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1)) # 按行来进行索引
print (torch.gather(b, dim=0, index=index_2)) # 按列来进行索引


'''
输出为:

1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]


 1  2
 6  4
[torch.FloatTensor of size 2x2]


 1  5  6
 1  2  3
[torch.FloatTensor of size 2x3]

'''

当dim=0时候,那么就是按照行进行索引,输出第一行位置为0,1的元素,即1,2 。第二行位置为2,0的元素,即 6,4。

当dim=1时候,那么就是按照列进行索引,输出第一列位置为0,第二列位置为1,第三列位置为1的元素,即1,5,6。输出第二列位置为0,第二列位置为0,第三列位置为0的元素,即1,2,3。

综上,总结一下,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。你可以选择按照行和列的位置进行索引。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值