Pytorch 中的 gather 解释和实例
前言
这是我的第一篇文章,想写博客主要想纪录一下自己的学习历程,方便之后回来复习能更快速理解,也在撰写的过程中来检视自己是否确实学会了,也希望能帮助到其他想学习的同学,这篇主要纪录了torcg.gather()
的使用,自己本身也花了一些时间来理解,后面也会附上一个实用的例子.
torch.gather()
首先附上官方文档 : torch.gather() ,文档内容就不浪费篇幅来贴了,能看懂的话应该也不会点进来这篇了.其实这个东西就是依据你的dim去作选择,例如:
下面第一个例子,我们选的dim = 0, 也就是沿着axis = 0这个方向去做选择,像是index里面的[2,2,1,0] 就是分别对应到:
第一个column的第二个值(index),
第二个column的第二个值(index),
第三个column的第一个值(index),
第四个column的第零个值(index),也就是[63, 69, 90, 68].
那么像第二个例子,我们选择的dim = 1,也就是说明我们沿着dim=1这个方向来选择,像是[2,0,1] 就是分别在第一个(index0)row里去找[2,0,1] ,因此也就对应到[26,93,32],其余的就自己练习看看了.
t = torch.randint(0,100,(3,4))
t
tensor([[93, 32, 26, 68],
[68, 17, 90, 33],
[63, 69, 19, 61]])
torch.gather(t,dim = 0,index = torch.tensor([[2,2,1,0],[1,0,2,2]]))
tensor([[63, 69, 90, 68],
[68, 32, 19, 61]])
>>> torch.gather(t,dim = 1,index = torch.tensor([[2,0,1],[3,0,1],[1,2,3]]))
tensor([[26, 93, 32],
[33, 68, 17],
[69, 19, 61]])
实用范例
附上一个最近遇到的例子,假设我们现在做目标检测,神经网路输出了box的预测以及class的预测,那如果我们取得依据他的信心值取得前面K个(top k) 预测之结果可以怎么做呢?
import torch
#set seed
torch.manual_seed(0)
#for object detection ,ouput would be box_preds and cls_preds
#In this case , we have 10 preditions and the number of classes = 3
batch_size = 2
cls_num = 3
box_preds = torch.randint(0,400,(batch_size,10,4))
cls_preds = torch.randn((batch_size,10,cls_num))
#select topk
#assume there're lots of predictions,we only select top k to predict
#to make it simple we only have 10 preds ,and pick top 3 by it's confidence score
_,top_K_cls = torch.topk(cls_preds.reshape(batch_size,-1),dim = 1,k = 3 )
#If you flatten the vector , you can get it's row by using / operation
#If you flatten the vector , you can get it's col by using % operation
row_wise = top_K_cls/cls_num
col_wise = top_K_cls%cls_num
#select topk box
top_K_boxes = torch.gather(box_preds,dim = 1 ,index = row_wise.unsqueeze(2).expand(-1,-1,4))
top_K_cls = torch.gather(cls_preds,dim = 1,index = row_wise.unsqueeze(2).expand(-1,-1,cls_num))
top_K_cls = torch.gather(top_K_cls,dim = 2,index = col_wise.unsqueeze(2))
总结
这篇稍微纪录了一下我怎么去理解这个函数的使用方法,个人觉得这个函数刚开始其实不太容易理解,我也是看了一些文后,慢慢有些理解,其实没看懂也没关系,可以多看一些文每个,人的理解方式都不太一样,可以确定的是,实际实作了几次后就会慢慢抓到感觉了.