Pytorch:.gather(1,)和.gather(0,)的区别

本文通过一个例子说明了在PyTorch中,.gather(1,)和.gather(0,)的区别。gather函数用于按指定维度根据索引选取张量元素。当dim=1时,它沿行选取元素,返回列向量;当dim=0时,它沿列选取元素,返回行向量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch中gather(1,)和.gather(0,)的区别?

在 PyTorch 中,.gather(dim, index) 函数用于根据给定的索引在指定的维度上获取张量的元素。其中,dim 表示要进行索引的维度,index 是包含索引值的一个张量。

下面以一个简单的例子来解释 .gather(1,) 和 .gather(0,) 的区别:

import torch

# 构造一个 3x4 的 Tensor,每个元素的值为对应位置的行列索引之和
x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]])

# 沿着行的方向使用 gather,将第 i 行的第 idx[i] 个元素取出来
idx = torch.tensor([1, 2, 0])
result = x.gather(1, idx.view(-1, 1))
print(result)  # prints: tensor([[1], [3], [2]])

# 沿着列的方向使用 gather,将第 j 列的第 idx[j] 个元素取出来
idx = torch.tensor([1, 2, 0, 3])
result = x.gather(0, idx.view(1, -1))
print(result)  # prints: tensor([[1, 3, 2, 3]])

在上面的例子中,我们首先构造了一个 3x4 的 Tensor x,每个元素的值为对应位置的行列索引之和。然后,我们分别使用 .gather(1, index) 和 .gather(0, index) 对其进行索引。

.gather(1, index) 表示沿着行的方向使用 gather,idx 张量的每个元素表示对应行的要取出的元素的列索引,dim=1 表示对列进行操作,那么返回的结果是一个列向量,它的第 i i i 个元素是第 i i i 行的第 i d x [ i ] idx[i] idx[i] 个元素。例如,当 idx=[1, 2, 0] 时,结果应该是 [1, 3, 2]。
.gather(0, index) 表示沿着列的方向使用 gather,idx 张量的每个元素表示对应列的要取出的元素的行索引,dim=0 表示对行进行操作,那么返回的结果是一个行向量,它的第 j j j 个元素是第 i d x [ j ] idx[j] idx[j] 列的第 j j j 个元素。例如,当 idx=[1, 2, 0, 3] 时,结果应该是 [1, 3, 2, 3]。
总之,.gather(dim, index) 函数可以根据指定的索引在指定的维度上对张量进行索引,但需要注意 dim 和 index 的值对于结果的影响。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值