Pytorch中gather(1,)和.gather(0,)的区别?
在 PyTorch 中,.gather(dim, index) 函数用于根据给定的索引在指定的维度上获取张量的元素。其中,dim 表示要进行索引的维度,index 是包含索引值的一个张量。
下面以一个简单的例子来解释 .gather(1,) 和 .gather(0,) 的区别:
import torch
# 构造一个 3x4 的 Tensor,每个元素的值为对应位置的行列索引之和
x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]])
# 沿着行的方向使用 gather,将第 i 行的第 idx[i] 个元素取出来
idx = torch.tensor([1, 2, 0])
result = x.gather(1, idx.view(-1, 1))
print(result) # prints: tensor([[1], [3], [2]])
# 沿着列的方向使用 gather,将第 j 列的第 idx[j] 个元素取出来
idx = torch.tensor([1, 2, 0, 3])
result = x.gather(0, idx.view(1, -1))
print(result) # prints: tensor([[1, 3, 2, 3]])
在上面的例子中,我们首先构造了一个 3x4 的 Tensor x,每个元素的值为对应位置的行列索引之和。然后,我们分别使用 .gather(1, index) 和 .gather(0, index) 对其进行索引。
.gather(1, index) 表示沿着行的方向使用 gather,idx 张量的每个元素表示对应行的要取出的元素的列索引,dim=1 表示对列进行操作,那么返回的结果是一个列向量,它的第
i
i
i 个元素是第
i
i
i 行的第
i
d
x
[
i
]
idx[i]
idx[i] 个元素。例如,当 idx=[1, 2, 0] 时,结果应该是 [1, 3, 2]。
.gather(0, index) 表示沿着列的方向使用 gather,idx 张量的每个元素表示对应列的要取出的元素的行索引,dim=0 表示对行进行操作,那么返回的结果是一个行向量,它的第
j
j
j 个元素是第
i
d
x
[
j
]
idx[j]
idx[j] 列的第
j
j
j 个元素。例如,当 idx=[1, 2, 0, 3] 时,结果应该是 [1, 3, 2, 3]。
总之,.gather(dim, index) 函数可以根据指定的索引在指定的维度上对张量进行索引,但需要注意 dim 和 index 的值对于结果的影响。