pytorch-全面讲解函数topk, scatter, gather

本文深入解析PyTorch中的topk、gather与scatter函数,详细介绍这些函数如何通过index-value映射方式处理矩阵数据,改变维度并进行排序、取值及替换操作,是掌握PyTorch矩阵操作技巧的必备指南。
部署运行你感兴趣的模型镜像

这三个函数在pytorch中关于矩阵操作的非常实用的函数。我认为要想熟练的使用pytorch,能够灵活的使用这三个函数是至关重要的

三者的相同点:维度->数据的映射方式

因为三者都存在相似的地方,所以我这里放在一起来讲。这个共同点就是index -> value的方式:这里以官方给的gather函数对应为例:

# for a 3-D data
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

这样一看,并不好理解,举个例子:

  • 关于shape的变化:
    • 输入为[3, 4, 1]的数据x | 它的index为[3, 2, 1]
    • x.gather(dim=1, index)输出维度为[3, 2, 1]。它保持另外两维不变,仅在这一维上操作。
  • 关于数据的变化
    • idx中的数据代表在指定维度上的index。 在这里插入图片描述

topk

其实前面讲的映射方式计算起来还是容易乱,不过幸好并不影响我们的使用。emm实在不能理解可以忽略,只需要知道在指定维度上操作即可

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
  • 主要用途:依照大小,从矩阵某维度取值和取索引。常与scatter、gather连用。
  • 函数返回两个变量:value和index
  • 维度变化:假设指定维度为1,则(b, n, m)-> (b, k, m)
  • 其它用途:topk的数据默认按照从大到小排列,因此我们可以当做矩阵中的数据排序来用,若largest=False则为升序:

在这里插入图片描述

gather

torch.gather(input, dim, index, out=None) → Tensor
  • 用途:依照index来对矩阵进行取值
  • 函数返回与输入idx维度相同的tensor
  • 维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n)。则输出为(b,k,n)。在维度1上按照index进行取值。

scatter

torch.scatter(input, dim, index, src) → Tensor
  • 用途:与gather类似,不过它并不用来取值。scatter用来更替矩阵中指定index位置的值。

  • 维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n),src=(b,m,n)。则输出为(b,k,n)。在维度1上按照index从src取值,然后替换到input上相同的index位置。

  • 两种用法:

    • 一般要求source的维度为input维度相同,如下例:

    在这里插入图片描述

    • 当然,也可以直接指定要替换的值,如下:

    在这里插入图片描述

您可能感兴趣的与本文相关的镜像

PyTorch 2.8

PyTorch 2.8

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值