torch.gather()总结

本文详细介绍了PyTorch中的gather函数,通过多个二维和三维案例解释了如何根据索引在不同维度上选取值。在二维情况下,dim=0表示选取列,dim=1表示选取行;在三维情况下,dim=0表示选取页,dim=1表示选取列,dim=2表示选取行。案例涵盖了选取单个值和多个值的情况,展示了gather在张量操作中的灵活性。
部署运行你感兴趣的模型镜像

torch.gather沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.

1. 二维情况下

(1)case1: dim=0

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
       			  [ 6,  7,  8],
       			  [ 9, 10, 11]])
       			  
tensor_1 tensor([[9, 7, 5]])
# note:  dim=0从列里面选,【9】是第一列中第2个数,【7】是第二列第1个数,【5】是第三列第0个数

(2)case2: dim=1

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2],[1],[0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		       	  [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[5],
		      	 [7],
		    	 [9]])
# note: dim=1, 从行里取【5】是第一行第二个数,【7】是第二行第1个数,【9】是第三行第0个数

case3: 一行中取多个数

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		          [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[ 5,  4,  3],
		       	[ 7,  7,  7],
		       	[ 9, 10,  9]])

Case4: 一列中取多个数

    tensor_0 = torch.arange(3, 12).view(3, 3)
    index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
    tensor_1 = tensor_0.gather(0, index)
    print("tensor_0:", tensor_0)
    print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		          [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[9, 7, 5],
		       	[6, 7, 8],
		        [3, 7, 5]])

2. 三维情况

case1: dim=1

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
    b = torch.gather(a, 1, index)
    print("a:", a)
    print("b:", b)

输出:

a: tensor([[[13,  1, 25, 18, 28],
         	[24, 19,  5, 25, 11],
         	[13, 13, 20,  9, 22]],

           [[22, 18, 12,  9,  1],
         	[ 6, 11, 23, 11, 29],
        	[15,  9,  8, 29,  6]]])
        	
b: tensor([[[13, 19, 20, 18, 22],
         	[13,  1, 25, 18, 28],
         	[24, 19,  5, 25, 11]],

           [[ 6,  9,  8, 29,  6],
         	[22, 18, 12,  9,  1],
         	[15,  9,  8, 29,  6]]])
         	
 # note: dim=1从列中取

case2: dim=2

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
    b = torch.gather(a, 2, index)
    print("a:", a)
    print("b:", b)

输出

a: tensor([[[ 0, 19,  3, 20, 29],
         [ 4,  2,  1,  8, 13],
         [16, 15, 13, 29, 10]],

        [[25, 18, 16,  0,  6],
         [ 3,  4, 13, 23, 19],
         [ 7, 21, 28, 17, 11]]])
         
b: tensor([[[ 0, 19,  3,  0,  3],
         [ 4,  4,  4,  4,  4],
         [15, 15, 15, 15, 15]],

        [[18, 16, 16, 16, 16],
         [ 3,  3,  3,  3,  3],
         [28, 28, 28, 28, 28]]])
#  dim=2 从行里取数

case3: dim=0

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
    b = torch.gather(a, 0, index)
    print("a:", a)
    print("b:", b)

输出

a: tensor([[[ 9,  3, 10, 19,  4],
         	[26, 19, 20,  9, 28],
         	[ 5, 21, 29, 26, 24]],

           [[10,  2, 11, 29, 26],
         	[20, 25, 17, 11, 16],
         	[ 4, 17, 27, 17, 29]]])
         
b: tensor([[[ 9,  2, 11, 19, 26],
         	[26, 25, 17, 11, 16],
         	[ 4, 17, 27, 17, 29]],

           [[10,  3, 10, 19,  4],
         	[26, 19, 20,  9, 28],
         	[ 4, 17, 29, 26, 24]]])
# dim = 0时,索引代表在第几页取数,取数的位置为索引i所在的坐标,如上:index[0][0][0]=0,表示取a中第0页(0,0)的数9,index[0][0][1]=1表示取第1页的(0,1)坐标的数3

在这里插入图片描述

总结

  • index的维数必须与输入维数相同,输入为2维矩阵,index也必须为2维矩阵
  • 在二维矩阵中dim=0 表示列,dim=1表示行,三维矩阵中,dim=0表示页,dim=1表示列,dim=2表示行

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

torch.gatherPyTorch 中的一个用于从给定维度上按索引取值的函数,根据一个索引张量`index`,从源张量`input`中收集值,并返回一个新的张量,常用于从张量的特定位置抽取元素的操作。 ### 函数签名 ```python torch.gather(input, dim, index, *, sparse_grad=False, out=None) ``` - `input`:输入张量,表示要从中收集元素的源张量。 - `dim`:要收集的维度索引。例如,对于一个二维张量,0 表示沿着行的维度,1 表示沿着列的维度。 - `index`:索引张量,其形状应与`input`张量在除了`dim`维度之外的其他维度上保持一致。索引张量中的值表示在`input`张量对应维度上要收集的元素的索引。 - `out`(可选):输出张量,如果提供,结果将存储在这个张量中。 ### 工作原理 该函数按照指定的维度`dim`和索引`index`从输入张量`input`中收集数值。`index`张量中的每个元素指定了在`input`张量中`dim`维度上的位置。根据`index`张量中的索引,在`input`张量中沿着`dim`维度收集元素,输出张量的形状与`index`张量的形状相同,这意味着除了`dim`维度之外,其他所有维度的大小都与`index`相同 [^1][^2]。 ### 使用示例 ```python import torch # 创建一个输入张量 input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 创建一个索引张量,其形状与输入张量相同 index_tensor = torch.tensor([[0, 2, 1], [2, 0, 1], [1, 0, 2]]) # 使用 torch.gather 收集元素,沿着列(dim=1) output_tensor = torch.gather(input_tensor, 1, index_tensor) print(output_tensor) ``` 另外,还有如下示例: ```python t = torch.Tensor([[1, 2], [3, 4]]) result = torch.gather(t, 1, torch.LongTensor([[0, 0], [1, 0]])) print(result) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值