结论:使用方法
# gather,沿dim指定的轴收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)会变成一列,y_hat的取y作为的索引的值
分步理解:先创建一个2*3的tensor
>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
tensor([[0.1000, 0.3000, 0.6000],
[0.3000, 0.2000, 0.5000]])
为了使用gather函数,我们得创建一个tensor作为gather得参数
>>y = torch.LongTensor([0, 2])
tensor([0, 2])
我们需要把y变个形状
>>y.view(-1, 1)
tensor([[0],
[2]])
先来看看使用得结果
>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
[0.5000]])
图解:

本文详细介绍了PyTorch中gather函数的应用方法。通过创建一个2*3的张量并利用另一个张量作为索引,展示了如何沿指定维度收集元素。通过具体实例帮助读者更好地理解gather函数的工作原理。
1301

被折叠的 条评论
为什么被折叠?



