pytorch中的gather和scatter函数

这篇博客主要介绍了PyTorch中的gather和scatter函数。gather函数沿着指定维度聚合输入张量的值,适用于KNN分类问题中获取最相似样本的标签。scatter函数则根据输入索引修改目标张量的指定位置,可用于生成one-hot标签。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近看代码时遇到了两个函数,查阅pytorch官方文档后一时半会儿也没弄懂,现在写篇笔记来加深一下印象。

gather

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
沿着给定的维度dim,将输入input指定位置的值聚合起来,指定位置由index决定。
indexinput必须有相同数量的维度,且满足1 <= index[dim] <= input[dim]index[other_dims] == input[other_dims]
对于3维的张量,公式为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

下面给出一个例子:

import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值