简介
本章节将介绍笔者学习pytorch源码中遇到的一些并行算法(基于pytorch)。当然这些源码都可以在官方的github上获取。
整体思路
入口为gather_topk函数,首先会经过RadixSelect函数找到第K大小的数(可能不是唯一的)。然后再并行找到大于这个值的所有数并按序前插以确定Index(本节举例最大Topk)。
首先需要了解的前提知识为基数排序,在Pytorch中topk是基于基数排序的变体实现的,了解其原理有助于了解topk的实现。
基数排序
RadixSelect
说这个函数为核心也不为过,这一节主要介绍这个函数。
寻常的基数排序是按位的,而Pytorch中是按两位的。具体在代码中体现为:
#define RADIX_SIZE 4
#define RADIX_BITS 2
#define RADIX_MASK 3
(这里举例最大Topk,且最大topk为按照11,10,01,00的顺序入桶,最小topk反之)
为啥说它是变体呢:topk其实并未将所有数据排序而是分组慢慢挑出最大的K个数据,随之找出这个第K大的数据。
-
灵魂函数 : getBitfield、setBitfield
如何将最大的k个数慢慢挑出来呢? 这里介绍一下灵魂函数getBitfield、setBitfield
作用:前者获取当前数据第i和第i+1位的数据(00,01,10,11按两位,其实就是对应基数排序的4个桶,step = RADIX_BITS),很容易理解。后者的作用是为了标记我们现在关注的数据,这个可能不太好理解,这里详细讲一下:
比如现在有1111,1110,1101,1100,0000,0001
这6个数而我们现在要找topk3。首先按照基数排序的思路
最大桶11
对应的元素数量为4:1111,1110,1101,1100
。此时最大的4
个数据大于3(topk3)也就是说我们只需从这4个数中继续找就行了,可以跳过当前剩余的入桶操作,这个setBitdield的作用就是记录这4个数的前缀11
帮助我们跳过其他数(之后只找前缀为11
的数)。在代码中具体表现为desired和desired_mask。这里简单提一下,后面会举例详细讲。
源码:
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0xff;
len &= 0xff;
unsigne