Pytorch 源码浅析1.Topk(上)

简介

本章节将介绍笔者学习pytorch源码中遇到的一些并行算法(基于pytorch)。当然这些源码都可以在官方的github上获取。

入口

整体思路

入口为gather_topk函数,首先会经过RadixSelect函数找到第K大小的数(可能不是唯一的)。然后再并行找到大于这个值的所有数并按序前插以确定Index(本节举例最大Topk)。

首先需要了解的前提知识为基数排序,在Pytorch中topk是基于基数排序的变体实现的,了解其原理有助于了解topk的实现。
基数排序

RadixSelect

说这个函数为核心也不为过,这一节主要介绍这个函数。
寻常的基数排序是按位的,而Pytorch中是按两位的。具体在代码中体现为:

#define RADIX_SIZE 4
#define RADIX_BITS 2
#define RADIX_MASK 3

(这里举例最大Topk,且最大topk为按照11,10,01,00的顺序入桶,最小topk反之)

为啥说它是变体呢:topk其实并未将所有数据排序而是分组慢慢挑出最大的K个数据,随之找出这个第K大的数据。

  1. 灵魂函数 : getBitfieldsetBitfield
    如何将最大的k个数慢慢挑出来呢? 这里介绍一下灵魂函数getBitfield、setBitfield
    作用:前者获取当前数据第i和第i+1位的数据(00,01,10,11按两位,其实就是对应基数排序的4个桶,step = RADIX_BITS),很容易理解。

    后者的作用是为了标记我们现在关注的数据,这个可能不太好理解,这里详细讲一下:
    比如现在有1111,1110,1101,1100,0000,0001这6个数而我们现在要找topk3。首先按照基数排序的思路
    最大桶11对应的元素数量为4:1111,1110,1101,1100。此时最大的4个数据大于3(topk3)也就是说我们只需从这4个数中继续找就行了,可以跳过当前剩余的入桶操作,这个setBitdield的作用就是记录这4个数的前缀11帮助我们跳过其他数(之后只找前缀为11的数)。在代码中具体表现为desireddesired_mask。这里简单提一下,后面会举例详细讲。
    源码

static __device__ __forceinline__
    unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
      pos &= 0xff;
      len &= 0xff;

      unsigne
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值