【时间】2019.03.19
【题目】选择函数 torch.gather()的理解
1、Pytorch中的torch.gather函数的含义
2、pytorch之torch.gather方法
torch.gather(input, dim, index, out=None) → Tensor
【注意】
返回的tensor的size与index的size一致。
dim用于指明index的元素值代表的维数。这个函数可以用来很方便地提取方阵
的对角元素。比如:
import torch as t
a = t.arange(0, 16).view(4, 4)
index = t.LongTensor([[0,1,2,3]])
b = t.gather(a,0, index)
print(a)
print(index)
print(b)
【运行结果】:表示从a中提取shape为index的shape(即1X2)的tensor,dim = 0指明index元素值的维度,所以所选择的完整索引值为[[(0,0),(1,1),(2,2),(3,3)]],即最终结果b取a的对角线。

官方给出的解释是这样的:
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
该博客主要介绍了Pytorch中的torch.gather函数。说明了其函数形式torch.gather(input, dim, index, out=None) → Tensor,返回的tensor的size与index的size一致,dim用于指明index元素值代表的维数,还给出了其官方解释及3维张量的输出定义。
1701

被折叠的 条评论
为什么被折叠?



