一.GEMM算法概述
1.1不采用数据预取
首先,我们明确GEMM中的具体参数,取bm=128,bn=128,bk=8,rm=8,rn=8。当这几个参数选定后直观地感受一下这几个参数意义,假定给了三个矩阵,A,B,C,其维度都是2048*2048。要求解C=A*B。那么我们需要开启(2048/128)*(2048/128)=256个block,每个block里面有(128/8)*(128/8)=256个线程,每个线程需要负责计算C矩阵中8*8=64个元素的结果,每个block负责256*64=16384个元素的结果。
明确了上面的参数之后,我们来仔细地观察其中一个block的计算逻辑。对于这个block而言,bk=8,需要进行2048/8=256次迭代,我们先把这个迭代成为大迭代,每一次大迭代都需要把A里面128*8=1024个元素和B里面8*128=1024个元素先放到shared memory中。然后这个block中的256个线程把结果计算出来。计算完之后,再进入下一次大迭代。不断重复该过程,直至这个block负责的16384个元素的结果被求解出。大迭代示意图如下:
随后再具体看看每一次大迭代中,block中的线程的计算逻辑。在进行一个大迭代时,shared memory中有128*8=1024个A矩阵元素和8*128=1024个B矩阵元素。随后,每个线程需要进行8次迭代,我们把这个迭代称为小迭代。bk=8,所以有8次小迭代。每一次小迭代中,每个线程需要从shared memory中拿到A矩阵的一小列和B矩阵的一小行,即8个A元素和8个B的元素。线程将这8+8=16个元素放置在寄存器中。每个线程需要负责8*8=64个元素的计算,一共会产生64条FFMA指令。小迭代示意图如下:
以上就是不采用数据预取的GEMM算法计算逻辑。总的来说,对于一个block而言,有256个大迭代,每个大迭代中又有8个小迭代,这是后续内容的基础。
1.2 采用数据预取
差异体现在两方面,第一个是开启的shared memory和寄存器数量,第二个是需要提前将一些数据放置到shared memory和寄存器中。
为了实现数据预取,需要开启两倍的shared memory和寄存器。也可以将原来的shared memory切分成两块,也就是将bm*bk和bk*bn的矩阵一分为二。以A中的小矩阵而言,变成了两个bm*bk/2。然后大迭代次数由原来的256变成了512,称为数据预取或者双缓冲。在一个block中,原来在shared memory中需要存储的数据是bm*bk+bk*bn。现在变成了bm*bk*2+bk*bn*2。在一个thread中,为了存储A和B的数据,原来需要使用rm+rn个寄存器,现在需要使用2*(rm+rn)个寄存器。为了方便介绍,用read SM和write SM代表用来读写的两块共享内存,并用read REG和write REG来表示用来读写的两块寄存器。
把共享内存和寄存器说明白后,我们看具体的计算逻辑。在执行256次大迭代之前,我们需要提前将第0次大迭代的数据存到write SM中,并且将第0次小迭代的数据存到write REG中。在完成这一个预取过程后,我们再来仔细地看看第0个大迭代。需要注意的是,上一轮大迭代的write SM就是这一轮迭代的read SM。上一轮小迭代的write REG就是这一轮的read REG。所以在进行第0个大迭代时,上面的write SM就变成了read SM。我们首先需要将下一轮大迭代的数据存到write SM中。由于从global memory中取数的时钟周期非常多。所以在等待数据取回的同时,对read SM中的数据进行计算。也就是我们在等待的同时,需要开启8次小迭代来进行计算。而小迭代中也存在着读写分离,在对read REG进行计算之前,需要先执行write REG的操作,通过这种方式来掩盖访存的latency。整体逻辑如下:
for k in 256 big_loop:
prefecth next loop data to write_SM
// compute in read_SM
for iter in 8 small_loop:
prefecth next loop data to write_REG
compute in read_REG
采用数据预取的GEMM计算流程。核心思想:提前将下一轮迭代所需要的数据取出然后放置到更近的存储中,然后通过pipline的形式来减少访存的latency。
二.GEMM代码解析
由于将数据从global memory中搬运到shared memory中还经过了寄存器,所以对prefetch过程进行了细化。
2.1参数说明
BLOCK_SIZE_M、BLOCK_SIZE_K、BLOCK_SIZE_N分别代表上下文的bm、bk、bn。中间两个参数,THREAD_SIZE_Y、THREAD_SIZE_X代表rm、rn。最后的参数ENABLE_DOUBLE_BUFFER代表是否采用双缓冲,即是否采用数据预取 ,即开启双缓冲的情况。
template <
const int BLOCK_SIZE_M, // height of block of C that each thread block