Torch.gather() 和 Tensor.gather() 的说明及示例
1. 函数原型
Tensor.gather(dim, index) 和 torch.gather(input, dim, index, …)使用方法相似
下面主要以torch.gather()来说明。
out=torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
1.1 定义:官网链接 https://pytorch.org/docs/stable/generated/torch.gather.html
#官网给出三维数据的说明
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k