选择函数 torch.gather()的理解

该博客主要介绍了Pytorch中的torch.gather函数。说明了其函数形式torch.gather(input, dim, index, out=None) → Tensor,返回的tensor的size与index的size一致,dim用于指明index元素值代表的维数,还给出了其官方解释及3维张量的输出定义。
部署运行你感兴趣的模型镜像

【时间】2019.03.19

【题目】选择函数 torch.gather()的理解

1、Pytorch中的torch.gather函数的含义
2、pytorch之torch.gather方法

torch.gather(input, dim, index, out=None) → Tensor

【注意】

返回的tensor的size与index的size一致。

dim用于指明index的元素值代表的维数。这个函数可以用来很方便地提取方阵

的对角元素。比如:

import torch as t
a = t.arange(0, 16).view(4, 4)

index = t.LongTensor([[0,1,2,3]])

b = t.gather(a,0, index)
print(a)
print(index)
print(b)

【运行结果】:表示从a中提取shape为index的shape(即1X2)的tensor,dim = 0指明index元素值的维度,所以所选择的完整索引值为[[(0,0),(1,1),(2,2),(3,3)]],即最终结果b取a的对角线。


官方给出的解释是这样的:
                    沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
                    对一个3维张量,输出可以定义为:

                out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0

                out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1

                out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值