pytorch中的gather函数_Pytorch中的torch.gather函数

本文详细介绍了PyTorch中的torch.gather函数,通过实例演示了如何使用该函数根据指定轴对张量进行元素检索,并展示了在不同维度下索引操作的结果。从矩阵操作到深度学习中的应用,为理解张量索引提供了清晰指南。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

gather函数的的官方文档:

torch.gather(input, dim, index, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # dim=0

out[i][j][k] = input[i][index[i][j][k]][k] # dim=1

out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

Parameters:

input (Tensor) – The source tensor

dim (int) – The axis along which to index

index (LongTensor) – The indices of elements to gather

out (Tensor, optional) – Destination tensor

Example:

>>> t = torch.Tensor([[1,2],[3,4]])

>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))

1 1

4 3

[torch.FloatTensor of size 2x2]

例子:

a=t.arange(0,16).view(4,4)

print(a)

index_1=t.LongTensor([[3,2,1,0]])

b=a.gather(0,index_1)

print(b)

index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()

c=a.gather(1,index_2)

print(c)

输出如下:

tensor([[ 0, 1, 2, 3],

[ 4, 5, 6, 7],

[ 8, 9, 10, 11],

[12, 13, 14, 15]])

tensor([[12, 9, 6, 3]])

tensor([[ 0],

[ 5],

[10],

[15]])

在上面的例子中,a是一个4×4矩阵:

1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];

2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T。

例子二:

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

y = torch.LongTensor([0, 2])

y_hat.gather(1, y.view(-1, 1))

输出:

tensor([[0.1000],

[0.5000]])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值