Pytorch 中的 gather 解释和实例

本文深入解析PyTorch中gather函数的使用方法,通过实例演示如何沿指定维度进行数据选择,特别聚焦于目标检测场景下,如何利用gather从预测结果中挑选出置信度最高的前K个预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch 中的 gather 解释和实例

前言

这是我的第一篇文章,想写博客主要想纪录一下自己的学习历程,方便之后回来复习能更快速理解,也在撰写的过程中来检视自己是否确实学会了,也希望能帮助到其他想学习的同学,这篇主要纪录了torcg.gather()的使用,自己本身也花了一些时间来理解,后面也会附上一个实用的例子.

torch.gather()

首先附上官方文档 : torch.gather() ,文档内容就不浪费篇幅来贴了,能看懂的话应该也不会点进来这篇了.其实这个东西就是依据你的dim去作选择,例如:
下面第一个例子,我们选的dim = 0, 也就是沿着axis = 0这个方向去做选择,像是index里面的[2,2,1,0] 就是分别对应到:
第一个column的第二个值(index),
第二个column的第二个值(index),
第三个column的第一个值(index),
第四个column的第零个值(index),也就是[63, 69, 90, 68].
那么像第二个例子,我们选择的dim = 1,也就是说明我们沿着dim=1这个方向来选择,像是[2,0,1] 就是分别在第一个(index0)row里去找[2,0,1] ,因此也就对应到[26,93,32],其余的就自己练习看看了.

 t = torch.randint(0,100,(3,4))
 t
tensor([[93, 32, 26, 68],
        [68, 17, 90, 33],
        [63, 69, 19, 61]])
torch.gather(t,dim = 0,index = torch.tensor([[2,2,1,0],[1,0,2,2]]))
tensor([[63, 69, 90, 68],
        [68, 32, 19, 61]])
     
>>> torch.gather(t,dim = 1,index = torch.tensor([[2,0,1],[3,0,1],[1,2,3]]))
tensor([[26, 93, 32],
        [33, 68, 17],
        [69, 19, 61]])

实用范例

附上一个最近遇到的例子,假设我们现在做目标检测,神经网路输出了box的预测以及class的预测,那如果我们取得依据他的信心值取得前面K个(top k) 预测之结果可以怎么做呢?

import torch

#set seed 
torch.manual_seed(0)

#for object detection ,ouput would be  box_preds and cls_preds 
#In this case , we have 10 preditions and the number of classes = 3
batch_size = 2 
cls_num = 3
box_preds =  torch.randint(0,400,(batch_size,10,4))
cls_preds =  torch.randn((batch_size,10,cls_num))

#select topk 
#assume there're lots of predictions,we only select top k to predict
#to make it simple we only have 10 preds ,and pick top 3 by it's confidence score 

_,top_K_cls =  torch.topk(cls_preds.reshape(batch_size,-1),dim = 1,k = 3 )

#If you flatten the vector ,  you can get it's row by using / operation
#If you flatten the vector ,  you can get it's col by using % operation 
row_wise = top_K_cls/cls_num 
col_wise = top_K_cls%cls_num

#select topk box 

top_K_boxes = torch.gather(box_preds,dim = 1 ,index = row_wise.unsqueeze(2).expand(-1,-1,4))

top_K_cls = torch.gather(cls_preds,dim = 1,index = row_wise.unsqueeze(2).expand(-1,-1,cls_num))

top_K_cls = torch.gather(top_K_cls,dim = 2,index = col_wise.unsqueeze(2))

总结

这篇稍微纪录了一下我怎么去理解这个函数的使用方法,个人觉得这个函数刚开始其实不太容易理解,我也是看了一些文后,慢慢有些理解,其实没看懂也没关系,可以多看一些文每个,人的理解方式都不太一样,可以确定的是,实际实作了几次后就会慢慢抓到感觉了.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值