DCU异构程序——GEMM

目录

一、概述

二、程序实现

三、编译运行


一、概述

        HIP属于显式编程模型,需要在程序中明确写出并行控制语句,包括数据传输、核函数启动等。核函数是运行在DCU上的函数,在CPU端运行的部分称为主机端(主要是执行管理和启动),DCU端运行的部分称为设备端(用于执行计算)。大概的流程如下图:

HIP程序流程

        ①主机端将需要并行计算的数据通过hipMemcpy()传递给DCU(将CPU存储的内容传递给DCU的显存);

        ②调用核函数启动函数hipLaunchKernelGGL()启动DCU,开始执行计算;

        ③设备端将计算好的结果数据通过hipMemcpy()从DCU复制回CPU。

        hipMemcpy()是阻塞式的,数据复制完成后才可以执行后续的程序;hipLaunchKernelGGL()是非阻塞式的,调用后立即返回,主机端程序继续向后执行。但是在Kernel计算完成之前,其后的hipMemcpy()不会开始执行,这是由于HIP的Stream机制:提交到同一个Stream(这里是默认Stream)的操作会按提交顺序依次执行。

二、程序实现

        下面是对GEMM的具体实现,GEMM.cpp:

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include "hip/hip_runtime.h"

/* Value assigned to the otherwise unused local `m` in main(). */
#define NUM 256
/* Initial value of every accumulator register. */
#define SCALAR_ZERO ((double)0)
/* Total number of doubles in the kernel's shared (LDS) array. */
#define LDS_NUM_ELEMENTS 2048
/* Offset of the B tile relative to the A tile inside one LDS buffer. */
#define LDS_OFFSET_B 512
/* Distance between the two LDS buffers the kernel alternates between. */
#define LDS_OFFSET_BLK 1024
/* Number of k-steps consumed per main-loop iteration of the kernel. */
#define DEPTH 8
/* LDS stride (in doubles) between consecutive k-slices of the A / B tile. */
#define MT0I 64
#define MT1J 64

/* DST += A * B, with every argument parenthesized. */
#define MAC(A,B,DST) ((DST) += (A) * (B))
/*
 * One multiply-accumulate statement. The trailing ';' is part of the macro
 * (as in the original definition), so MAC_4x4 / MAC_4x4_BLK can be used with
 * or without a following semicolon.
 */
#define TYPE_MAC(MULA,MULB,DST) MAC(MULA,MULB,DST);
/* Index offset of the second register set (rA[4..7] / rB[4..7]). */
#define TT 4
/* Distance between the two halves of the tile a thread reads/writes. */
#define SIZE_HALF_A 32
#define SIZE_HALF_B 32
#define SIZE_HALF_C 32

/*
 * 4x4 outer-product update rC[4*r + c] += rA[r] * rB[c] using the first
 * register set. Expects arrays rA, rB and rC in the enclosing scope.
 */
#define MAC_4x4 \
    TYPE_MAC(rA[0],rB[0],rC[0]) \
    TYPE_MAC(rA[0],rB[1],rC[1]) \
    TYPE_MAC(rA[0],rB[2],rC[2]) \
    TYPE_MAC(rA[0],rB[3],rC[3]) \
    TYPE_MAC(rA[1],rB[0],rC[4]) \
    TYPE_MAC(rA[1],rB[1],rC[5]) \
    TYPE_MAC(rA[1],rB[2],rC[6]) \
    TYPE_MAC(rA[1],rB[3],rC[7]) \
    TYPE_MAC(rA[2],rB[0],rC[8]) \
    TYPE_MAC(rA[2],rB[1],rC[9]) \
    TYPE_MAC(rA[2],rB[2],rC[10]) \
    TYPE_MAC(rA[2],rB[3],rC[11]) \
    TYPE_MAC(rA[3],rB[0],rC[12]) \
    TYPE_MAC(rA[3],rB[1],rC[13]) \
    TYPE_MAC(rA[3],rB[2],rC[14]) \
    TYPE_MAC(rA[3],rB[3],rC[15])

/*
 * Same 4x4 update as MAC_4x4, but reading the second register set
 * (rA[TT..TT+3], rB[TT..TT+3]).
 */
#define MAC_4x4_BLK \
    TYPE_MAC(rA[0+TT],rB[0+TT],rC[0]) \
    TYPE_MAC(rA[0+TT],rB[1+TT],rC[1]) \
    TYPE_MAC(rA[0+TT],rB[2+TT],rC[2]) \
    TYPE_MAC(rA[0+TT],rB[3+TT],rC[3]) \
    TYPE_MAC(rA[1+TT],rB[0+TT],rC[4]) \
    TYPE_MAC(rA[1+TT],rB[1+TT],rC[5]) \
    TYPE_MAC(rA[1+TT],rB[2+TT],rC[6]) \
    TYPE_MAC(rA[1+TT],rB[3+TT],rC[7]) \
    TYPE_MAC(rA[2+TT],rB[0+TT],rC[8]) \
    TYPE_MAC(rA[2+TT],rB[1+TT],rC[9]) \
    TYPE_MAC(rA[2+TT],rB[2+TT],rC[10]) \
    TYPE_MAC(rA[2+TT],rB[3+TT],rC[11]) \
    TYPE_MAC(rA[3+TT],rB[0+TT],rC[12]) \
    TYPE_MAC(rA[3+TT],rB[1+TT],rC[13]) \
    TYPE_MAC(rA[3+TT],rB[2+TT],rC[14]) \
    TYPE_MAC(rA[3+TT],rB[3+TT],rC[15])

/*
 * DST = ALPHA*REG + BETA*SRC, skipping the read of SRC when BETA is zero.
 * No trailing ';' - the caller supplies it.
 */
#define TYPE_MAC_WRITE(DST,SRC,ALPHA,REG,BETA) ((DST) = (0 != (BETA)) ? (ALPHA)*(REG) + (BETA)*(SRC) : (ALPHA)*(REG))

/*
 * Device-side timestamp helper.
 *
 * Executes the `s_memtime` instruction and returns the 64-bit value it
 * produces. The instruction is bracketed by `s_waitcnt lgkmcnt(0)` so that
 * outstanding LGKM-counted operations have completed before the sample is
 * taken and the result has arrived before it is returned.
 * NOTE(review): the tick frequency of this counter is device specific -
 * confirm it before converting the value to wall-clock time.
 */
__device__ uint64_t inline readtime()
{
    uint64_t ticks;
    asm volatile(
        "s_waitcnt lgkmcnt(0)\n\t"
        "s_memtime %0\n\t"
        "s_waitcnt lgkmcnt(0)\n\t"
        : "=s"(ticks));
    return ticks;
}

/*
 * Double-precision GEMM kernel for a single 64x64 output tile.
 *
 * Computes, for the tile handled by this block,
 *     dst_c[i*size_n + j] = alpha * sum_k src_a[k*size_m + i] * src_b[k*size_n + j]
 *                           + beta * dst_c[i*size_n + j]
 * which is the same layout the host reference mul_cpu() uses.
 *
 * Work distribution (all derived from the thread index `serial`):
 *   - Global -> LDS: each thread copies 2 consecutive doubles of A and of B
 *     per DEPTH-deep tile (8 k-rows x 32 thread columns x 2 doubles).
 *   - LDS -> registers: thread (lri, lrj) = (serial & 15, serial >> 4) reads
 *     elements {2*lri, 2*lri+1, 2*lri+32, 2*lri+33} of A and the analogous
 *     elements of B for every k-step and accumulates a 4x4 block in rC.
 *   - The LDS array holds two buffers LDS_OFFSET_BLK apart; while one buffer
 *     is consumed, the next tile is staged in registers and then written to
 *     the other buffer.
 *
 * Requirements implied by the index math: launched with one block of 256
 * threads, a 64x64 tile (size_m, size_n >= 64), and size_k a positive
 * multiple of DEPTH.
 *
 * clock_cycle receives the readtime() delta measured by thread 0.
 *
 * Fixes relative to the previous revision:
 *   - `serial` is the thread index (it was hard-coded to 0, so every thread
 *     computed the same sub-tile).
 *   - the B write offset is advanced from `lwb` (it was derived from `lwa`).
 *   - the global prefetch is skipped on the last iteration, where it would
 *     read past the end of A and B.
 *   - only one thread stores the timing result.
 */
__global__ void global_depth8_lds_2_bank(double *src_a, double *src_b, double *dst_c, double alpha, double beta, int size_m, int size_n, int size_k, uint64_t *clock_cycle)
{
    uint64_t cycle1 = readtime();
    __shared__ double localMemory[LDS_NUM_ELEMENTS];

    unsigned int serial = threadIdx.x;

    /* Global-read mapping: 32 threads per k-row, 2 doubles per thread. */
    unsigned int grj = (serial >> 5);
    unsigned int gri = (serial & 31);

    unsigned int goa = grj * size_m + gri * 2;
    unsigned int gob = grj * size_n + gri * 2;

    /* LDS-write mapping: A tile at offset 0, B tile at LDS_OFFSET_B. */
    unsigned int lwa = serial * 2;
    unsigned int lwb = serial * 2 + LDS_OFFSET_B;

    double *local_write_A = localMemory + lwa;
    double *local_write_B = localMemory + lwb;

    /* LDS-read mapping: 16x16 thread grid over the 64x64 tile. */
    unsigned int lrj = (serial >> 4);
    unsigned int lri = (serial & 15);

    unsigned int lra = lri * 2;
    unsigned int lrb = lrj * 2 + LDS_OFFSET_B;

    double *local_read_A = localMemory + lra;
    double *local_read_B = localMemory + lrb;

    /* Top-left element of this thread's 4x4 output block. */
    unsigned int goc = (lri * size_n + lrj) * 2;
    double *global_address_C = dst_c + goc;
    double *global_address_A = src_a + goa;
    double *global_address_B = src_b + gob;

    double rA[8], rB[8], rC[16];
    double global_a0, global_a1, global_b0, global_b1;

    rC[0] = SCALAR_ZERO;
    rC[1] = SCALAR_ZERO;
    rC[2] = SCALAR_ZERO;
    rC[3] = SCALAR_ZERO;
    rC[4] = SCALAR_ZERO;
    rC[5] = SCALAR_ZERO;
    rC[6] = SCALAR_ZERO;
    rC[7] = SCALAR_ZERO;
    rC[8] = SCALAR_ZERO;
    rC[9] = SCALAR_ZERO;
    rC[10] = SCALAR_ZERO;
    rC[11] = SCALAR_ZERO;
    rC[12] = SCALAR_ZERO;
    rC[13] = SCALAR_ZERO;
    rC[14] = SCALAR_ZERO;
    rC[15] = SCALAR_ZERO;

    /* Prologue: stage the first DEPTH-deep tile in LDS buffer 0. */
    global_a0 = *(global_address_A + 0);
    global_a1 = *(global_address_A + 1);
    global_b0 = *(global_address_B + 0);
    global_b1 = *(global_address_B + 1);

    global_address_A += DEPTH * size_m;
    global_address_B += DEPTH * size_n;

    *(local_write_A + 0) = global_a0;
    *(local_write_A + 1) = global_a1;
    *(local_write_B + 0) = global_b0;
    *(local_write_B + 1) = global_b1;

    /* Subsequent writes go to the other LDS buffer. */
    lwa = (lwa + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;
    lwb = (lwb + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;

    local_write_A = localMemory + lwa;
    local_write_B = localMemory + lwb;

    __syncthreads();

    /* Preload k-step 0 into the first register set. */
    rA[0] = *(local_read_A + 0);
    rA[1] = *(local_read_A + 1);
    rA[2] = *(local_read_A + 0 + SIZE_HALF_A);
    rA[3] = *(local_read_A + 1 + SIZE_HALF_A);

    rB[0] = *(local_read_B + 0);
    rB[1] = *(local_read_B + 1);
    rB[2] = *(local_read_B + 0 + SIZE_HALF_B);
    rB[3] = *(local_read_B + 1 + SIZE_HALF_B);

    local_read_A += MT0I;
    local_read_B += MT1J;

    for (int i = 0; i < size_k; i += DEPTH)
    {
        /*
         * Prefetch the next tile from global memory. On the last iteration
         * there is no next tile; reading it would run past the end of A/B.
         */
        if (i + DEPTH < size_k)
        {
            global_a0 = *(global_address_A + 0);
            global_a1 = *(global_address_A + 1);
            global_b0 = *(global_address_B + 0);
            global_b1 = *(global_address_B + 1);

            global_address_A += DEPTH * size_m;
            global_address_B += DEPTH * size_n;
        }

        /* k-step 1 -> second register set, then consume k-step 0. */
        rA[0+TT] = *(local_read_A + 0);
        rA[1+TT] = *(local_read_A + 1);
        rA[2+TT] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3+TT] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0+TT] = *(local_read_B + 0);
        rB[1+TT] = *(local_read_B + 1);
        rB[2+TT] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3+TT] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4

        /* k-step 2 -> first register set, then consume k-step 1. */
        rA[0] = *(local_read_A + 0);
        rA[1] = *(local_read_A + 1);
        rA[2] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0] = *(local_read_B + 0);
        rB[1] = *(local_read_B + 1);
        rB[2] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4_BLK

        /* k-step 3 -> second register set, then consume k-step 2. */
        rA[0+TT] = *(local_read_A + 0);
        rA[1+TT] = *(local_read_A + 1);
        rA[2+TT] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3+TT] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0+TT] = *(local_read_B + 0);
        rB[1+TT] = *(local_read_B + 1);
        rB[2+TT] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3+TT] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4

        /* k-step 4 -> first register set, then consume k-step 3. */
        rA[0] = *(local_read_A + 0);
        rA[1] = *(local_read_A + 1);
        rA[2] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0] = *(local_read_B + 0);
        rB[1] = *(local_read_B + 1);
        rB[2] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4_BLK

        /* k-step 5 -> second register set, then consume k-step 4. */
        rA[0+TT] = *(local_read_A + 0);
        rA[1+TT] = *(local_read_A + 1);
        rA[2+TT] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3+TT] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0+TT] = *(local_read_B + 0);
        rB[1+TT] = *(local_read_B + 1);
        rB[2+TT] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3+TT] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4

        /* k-step 6 -> first register set, then consume k-step 5. */
        rA[0] = *(local_read_A + 0);
        rA[1] = *(local_read_A + 1);
        rA[2] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0] = *(local_read_B + 0);
        rB[1] = *(local_read_B + 1);
        rB[2] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;
        MAC_4x4_BLK

        /* k-step 7 (last of this tile) -> second register set. */
        rA[0+TT] = *(local_read_A + 0);
        rA[1+TT] = *(local_read_A + 1);
        rA[2+TT] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3+TT] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0+TT] = *(local_read_B + 0);
        rB[1+TT] = *(local_read_B + 1);
        rB[2+TT] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3+TT] = *(local_read_B + 1 + SIZE_HALF_B);

        /* Store the prefetched tile into the LDS buffer not being read. */
        *(local_write_A + 0) = global_a0;
        *(local_write_A + 1) = global_a1;
        *(local_write_B + 0) = global_b0;
        *(local_write_B + 1) = global_b1;

        /* Swap write buffers; each offset advances from its own value. */
        lwa = (lwa + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;
        lwb = (lwb + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;

        local_write_A = localMemory + lwa;
        local_write_B = localMemory + lwb;

        /* Swap read buffers and rewind to k-step 0 of the new tile. */
        lra = (lra + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;
        lrb = (lrb + LDS_OFFSET_BLK) % LDS_NUM_ELEMENTS;

        local_read_A = localMemory + lra;
        local_read_B = localMemory + lrb;
        MAC_4x4

        /* Make the freshly written tile visible to every thread. */
        __syncthreads();

        /* Preload k-step 0 of the next tile (unused after the last pass). */
        rA[0] = *(local_read_A + 0);
        rA[1] = *(local_read_A + 1);
        rA[2] = *(local_read_A + 0 + SIZE_HALF_A);
        rA[3] = *(local_read_A + 1 + SIZE_HALF_A);

        rB[0] = *(local_read_B + 0);
        rB[1] = *(local_read_B + 1);
        rB[2] = *(local_read_B + 0 + SIZE_HALF_B);
        rB[3] = *(local_read_B + 1 + SIZE_HALF_B);

        local_read_A += MT0I;
        local_read_B += MT1J;

        /* Consume k-step 7 of the tile just finished. */
        MAC_4x4_BLK;
    }

    /* Write back the 4x4 block: rows {0, 1, 32, 33}, columns {0, 1, 32, 33}. */
    TYPE_MAC_WRITE(*(global_address_C+0), *(global_address_C+0), alpha, rC[0], beta);
    TYPE_MAC_WRITE(*(global_address_C+1), *(global_address_C+1), alpha, rC[1], beta);
    TYPE_MAC_WRITE(*(global_address_C+0+SIZE_HALF_C), *(global_address_C+0+SIZE_HALF_C), alpha, rC[2], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+SIZE_HALF_C), *(global_address_C+1+SIZE_HALF_C), alpha, rC[3], beta);

    TYPE_MAC_WRITE(*(global_address_C+0+size_n), *(global_address_C+0+size_n), alpha, rC[4], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+size_n), *(global_address_C+1+size_n), alpha, rC[5], beta);
    TYPE_MAC_WRITE(*(global_address_C+0+SIZE_HALF_C+size_n), *(global_address_C+0+SIZE_HALF_C+size_n), alpha, rC[6], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+SIZE_HALF_C+size_n), *(global_address_C+1+SIZE_HALF_C+size_n), alpha, rC[7], beta);

    TYPE_MAC_WRITE(*(global_address_C+0+size_n*SIZE_HALF_C), *(global_address_C+0+size_n*SIZE_HALF_C), alpha, rC[8], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+size_n*SIZE_HALF_C), *(global_address_C+1+size_n*SIZE_HALF_C), alpha, rC[9], beta);
    TYPE_MAC_WRITE(*(global_address_C+0+SIZE_HALF_C+size_n*SIZE_HALF_C), *(global_address_C+0+SIZE_HALF_C+size_n*SIZE_HALF_C), alpha, rC[10], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+SIZE_HALF_C+size_n*SIZE_HALF_C), *(global_address_C+1+SIZE_HALF_C+size_n*SIZE_HALF_C), alpha, rC[11], beta);

    TYPE_MAC_WRITE(*(global_address_C+0+size_n*(SIZE_HALF_C+1)), *(global_address_C+0+size_n*(SIZE_HALF_C+1)), alpha, rC[12], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+size_n*(SIZE_HALF_C+1)), *(global_address_C+1+size_n*(SIZE_HALF_C+1)), alpha, rC[13], beta);
    TYPE_MAC_WRITE(*(global_address_C+0+SIZE_HALF_C+size_n*(SIZE_HALF_C+1)), *(global_address_C+0+SIZE_HALF_C+size_n*(SIZE_HALF_C+1)), alpha, rC[14], beta);
    TYPE_MAC_WRITE(*(global_address_C+1+SIZE_HALF_C+size_n*(SIZE_HALF_C+1)), *(global_address_C+1+SIZE_HALF_C+size_n*(SIZE_HALF_C+1)), alpha, rC[15], beta);

    __syncthreads();

    /* A single writer avoids 256 threads racing on the same location. */
    if (serial == 0)
    {
        *clock_cycle = readtime() - cycle1;
    }
}

/*
 * Host reference GEMM used to validate the device result.
 *
 * For every (row, col) in the size_m x size_n output:
 *     dst_c[row*size_n + col] = alpha * sum_k src_a[k*size_m + row] * src_b[k*size_n + col]
 *                               + beta * dst_c[row*size_n + col]
 * i.e. A is indexed k-major with stride size_m, B is indexed k-major with
 * stride size_n, and C is row-major with stride size_n. dst_c is updated
 * in place.
 */
void mul_cpu(double *src_a, double *src_b, double *dst_c, double alpha, double beta, int size_m, int size_n, int size_k)
{
    for (int row = 0; row < size_m; ++row) {
        double *c_row = dst_c + row * size_n;

        for (int col = 0; col < size_n; ++col) {
            /* Dot product along k, accumulated in ascending k order. */
            double acc = 0;
            for (int depth = 0; depth < size_k; ++depth) {
                acc += src_a[depth * size_m + row] * src_b[depth * size_n + col];
            }

            c_row[col] = alpha * acc + beta * c_row[col];
        }
    }
}

int main(int argc, char *argv[])
{
    double *src_a, *src_b, *out_cpu;
    double *a_device, *b_device, *c_device, *out_gpu;
    double alpha = 2.0, beta = 3.0;

    int size_m = 64, size_n = 64, size_k = 128;
    int m = NUM;
    int i, error = 0;
    uint64_t *clock_cycle;
    uint64_t time[1];

    src_a = (double *)malloc(size_m * size_k * sizeof(double));
    src_b = (double *)malloc(size_k * size_n * sizeof(double));
    out_cpu = (double *)malloc(size_m * size_n * sizeof(double));
    out_gpu = (double *)malloc(size_m * size_n * sizeof(double));

    hipMalloc((void**)&a_device, size_m * size_k * sizeof(double));
    hipMalloc((void**)&b_device, size_k * size_n * sizeof(double));
    hipMalloc((void**)&c_device, size_m * size_n * sizeof(double));
    hipMalloc((void**)&clock_cycle, 1 * sizeof(uint64_t));

    for(i = 0; i < size_m * size_k; i++)
    {
        src_a[i] = rand()%128;
    }
    for(i = 0; i < size_k * size_n; i++)
    {
        src_b[i] = rand()%128;
    }
    for(i = 0; i < size_m * size_n; i++)
    {
        out_gpu[i] = out_cpu[i] = rand()%128;
    }

    hipInit(0);
    hipDevice_t device;
    hipCtx_t context;
    hipDeviceGet(&device, 0);
    hipCtxCreate(&context, 0, device);

    hipMemcpy(a_device, src_a, size_m * size_k * sizeof(double), hipMemcpyHostToDevice);
    hipMemcpy(b_device, src_b, size_n * size_k * sizeof(double), hipMemcpyHostToDevice);
    hipMemcpy(c_device, out_cpu, size_m * size_n * sizeof(double), hipMemcpyHostToDevice);
    hipLaunchKernelGGL(global_depth8_lds_2_bank, dim3(1,1,1), dim3(256,1,1), 0, 0, a_device, b_device, c_device, alpha, beta, size_m, size_n, size_k, clock_cycle);
    hipMemcpy(time, clock_cycle, 1 * sizeof(uint64_t), hipMemcpyDeviceToHost);
    hipMemcpy(out_gpu, c_device, size_n * size_m * sizeof(double), hipMemcpyDeviceToHost);
    mul_cpu(src_a, src_b, out_cpu, alpha, beta, size_m, size_n, size_k);

    for(i = 0; i < size_n * size_m; i++)
    {
        if(fabs(out_gpu[i] - out_cpu[i]) > 1e-6)
        {
            error++;
            printf("%d, %lf, %lf\n", i, out_gpu[i], out_cpu[i]);
        }
    }

    if(error == 0)
        printf("**********result is ok!**********\n");

    double run_time = (double)time[0]/(1443.0*1000);
    printf("clock %lld, time is %f (ms), GFlops is %f GFlops\n\n", time[0], run_time, (double)(2*size_m*size_n*size_k)/(run_time*1000000));

    printf("Finish\n");

    return 0;
}

三、编译运行

        HIP程序采用hipcc编译,例如:hipcc GEMM.cpp -o GEMM,编译完成后执行 ./GEMM 即可运行。

运行结果:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猿核试Bug愁

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值