OpenCL(sgemm)

例程介绍

本例程在OpenCL设备端(DSP)执行单精度矩阵乘法(sgemm)运算,并将其结果与主机端利用CBLAS库计算的结果进行比较。本文仅关注与OpenCL或TI设备相关的代码,其余算法逻辑不作深入研究。

例程源码

Host端源码

//main.cpp
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <math.h>
#include <signal.h>
#include <ocl_util.h>
#include "kernel.dsp_h"

#ifdef _TI_RTOS
#include "../rtos_main.c"
    #if ti_sysbios_BIOS_version <= (0x65200)
    #include <ti/sysbios/posix/time.h>
    #else
    #include <ti/posix/gcc/time.h>
    #endif
#else
#include <time.h>
#endif

extern "C" {
#include "cblas.h"
}

using namespace std;
using namespace cl;

/* Elapsed time in seconds from *t1 to *t2 (negative when t2 precedes t1). */
static double clock_diff (struct timespec *t1, struct timespec *t2)
{
    const double whole_seconds = t2->tv_sec - t1->tv_sec;
    const double fraction      = (t2->tv_nsec - t1->tv_nsec) / 1e9;
    return whole_seconds + fraction;
}

/* Forward declarations of the host-side helpers used by main().            */

/* Print a rows x cols matrix stored in the given memory order.             */
void PrintMatrix(float *mat, int rows, int cols, enum CBLAS_ORDER mem_order);
/* Host reference: C = alpha * A * B + beta * C through cblas_sgemm.        */
void MatmulHost_ATLAS(enum CBLAS_ORDER mem_order,
                const float* A, const float *B, float *C, int M, int N, int K,
                float alpha, float beta);
/* Same parameter list as MatmulHost_ATLAS; no definition is visible in     */
/* this part of the file.                                                   */
void MatmulHost_loopnest(enum CBLAS_ORDER mem_order,
                const float* A, const float *B, float *C, int M, int N, int K,
                float alpha, float beta);
/* Compare Mat against Golden element by element; returns the mismatch      */
/* count.                                                                   */
int  CheckForErrors(const float *Mat, const float *Golden, int M, int N, int K,
                    enum CBLAS_ORDER mem_order);

/* ======================================================================== */
/*  Global Variables                                                        */
/*  Run configuration shared by main(), HandleOptions() and                 */
/*  SetSgemmParams().                                                       */
/* ======================================================================== */
float alpha                         = 1.0f;  // scale on A*B; randomized with -r
float beta                          = 0.0f;  // scale on input C; randomized with -r
enum CBLAS_ORDER order              = CblasColMajor; // "-o r" selects row-major
#ifndef _TI_RTOS
bool check                          = true;  // verify against host sgemm; -d clears
#else
bool check                          = false; // host verification off by default on TI-RTOS
#endif
bool random_in                      = false; // -r: random inputs, alpha and beta
bool calc_check                     = false; // -x: print the derived blocking parameters

int M                               = 0;     // rows of A and C    (0 = not yet set)
int N                               = 0;     // columns of B and C (0 = not yet set)
int K                               = 0;     // columns of A, rows of B (0 = not yet set)
int L2_BUF_SIZE                     = 0;     // bytes of __local memory passed to the kernel
int MSMC_BUF_SIZE                   = 0;     // bytes of MSMC buffer; 0 means "do not use MSMC"
int NUMAPANELS                      = 0;     // set by SetSgemmParams(), passed to the kernel
int NUMBPANELS                      = 0;     // set by SetSgemmParams(), passed to the kernel
int NUMCOMPUNITS                    = 0;     // device compute units; used as global work size

/* ======================================================================== */
/*  Function Headers                                                        */
/* ======================================================================== */
void PrintUsageAndExit();                    // prints help text, then exit(0)
void HandleOptions(int argc, char* argv[]);  // parses the command line into the globals
void SetSgemmParams(Device& device);         // queries the device, fills the sizes above

/* ======================================================================== */
/*  MAIN                                                                    */
/* ======================================================================== */
/*-----------------------------------------------------------------------------
* Program entry point (main() on the host OS, ocl_main() task under TI-RTOS).
*
* Flow: parse options -> create context/queue on the first accelerator device
* -> derive blocking parameters -> build the embedded kernel binary ->
* allocate and fill A, B, C -> run K_ocl_sgemm_dsp and time it -> optionally
* run the host cblas reference and compare.
*
* Returns (non-RTOS build) the number of mismatching elements reported by
* CheckForErrors(), 0 when checking is disabled, or -1 when the device
* reports zero compute units.  OpenCL errors are printed and end the program
* through exit(-1).
*
* NOTE(review): RETURN is only #defined in the non-RTOS branch below;
* presumably the TI-RTOS build gets it from rtos_main.c -- confirm.
*----------------------------------------------------------------------------*/
#ifdef _TI_RTOS
void ocl_main(UArg arg0, UArg arg1)
{
   int    argc = (int)     arg0;
   char **argv = (char **) arg1;
#else
#define RETURN(x) return x
int main(int argc, char *argv[])
{
#endif
   /*-------------------------------------------------------------------------
   * Route SIGABRT and SIGTERM to exit() so that a normal exit path is taken
   * (original intent: make sure dtors run and the dsp is reset properly).
   * NOTE(review): SIGINT (ctrl-c) is NOT registered here, although the
   * original comment spoke of catching ctrl-c.
   *------------------------------------------------------------------------*/
   signal(SIGABRT, exit);
   signal(SIGTERM, exit);

   int errs = 0;

    /* ---------------------------------------------------------------- */
    /*  Handle command line options to get M,N,K                        */
    /* ---------------------------------------------------------------- */
    HandleOptions(argc,argv);

    try
    {
       // Only the first accelerator device in the context is used.
       Context context(CL_DEVICE_TYPE_ACCELERATOR);
       std::vector<Device> devices = context.getInfo<CL_CONTEXT_DEVICES>();
       CommandQueue Q(context, devices[0]);

       /*---------------------------------------------------------------------
       * Determine platform, set sgemm blocking/tiling parameters
       *--------------------------------------------------------------------*/
       SetSgemmParams(devices[0]);
       if (NUMCOMPUNITS == 0)  RETURN(-1);

       // Random values (inputs, alpha, beta) are integers in [1, VALRANGE].
       int VALRANGE = 17;
       if (random_in)
       {
           srand(time(NULL));
           alpha = (float) (rand() % VALRANGE + 1);
           beta  = (float) (rand() % VALRANGE + 1);
       }

       printf("C[%d,%d] = alpha * A[%d,%d] * B[%d,%d] + beta * C[%d,%d], %s\n",
              M,N,M,K,K,N, M,N,
              (order == CblasRowMajor ? "use row-major storage"
                                      : "use col-major storage"));
       printf("alpha=%f, beta=%f\n\n", alpha, beta);

       double dsp_elapsed = 0;
       // 2*M*N*K floating-point operations, expressed in units of 1e9.
       double total_GFLOP = 2.0*M*N*K*1.0e-9;

       /*---------------------------------------------------------------------
       * Build kernel from pre-compiled and embedded binary
       *--------------------------------------------------------------------*/
       Program::Binaries binary(1, make_pair(kernel_dsp_bin,
                                             sizeof(kernel_dsp_bin)));
       Program             program = Program(context, devices, binary);
       program.build(devices);

       /* ---------------------------------------------------------------- */
       /*  Allocate Buffers:                                               */
       /* ---------------------------------------------------------------- */
       Buffer bufA   (context, CL_MEM_READ_ONLY,  M*K*sizeof(float));
       Buffer bufB   (context, CL_MEM_READ_ONLY,  K*N*sizeof(float));
       Buffer bufC   (context, CL_MEM_READ_WRITE, M*N*sizeof(float));
       // The kernel always takes an MSMC buffer argument; when MSMC is not
       // used a 4-byte ordinary buffer is passed as a placeholder.
       Buffer *bufMsmc = NULL;
       if (MSMC_BUF_SIZE != 0)
           bufMsmc = new Buffer(context, CL_MEM_READ_WRITE|CL_MEM_USE_MSMC_TI,
                                MSMC_BUF_SIZE);
       else
           bufMsmc = new Buffer(context, CL_MEM_READ_WRITE, 4); // dummy one

       /* ---------------------------------------------------------------- */
       /*  Initialized input arrays with random test data.                 */
       /* ---------------------------------------------------------------- */
       // Blocking maps (CL_TRUE): the host pointers are usable on return.
       float *A = (float*) Q.enqueueMapBuffer(bufA, CL_TRUE, CL_MAP_WRITE, 0,
                                              M*K*sizeof(float));
       float *B = (float*) Q.enqueueMapBuffer(bufB, CL_TRUE, CL_MAP_WRITE, 0,
                                              K*N*sizeof(float));
       float *C = (float*) Q.enqueueMapBuffer(bufC, CL_TRUE, CL_MAP_WRITE, 0,
                                              M*N*sizeof(float));
       float *gold = nullptr;

       // Without -r the inputs are a fixed pattern derived from the index.
       cout << "Generating Input Data ..." << flush;
       for (int i = 0; i < M*K; ++i)
           A[i] = random_in ? (float)(rand() % VALRANGE + 1) : 1 + (i & 7);
       for (int i = 0; i < K*N; ++i)
           B[i] = random_in ? (float)(rand() % VALRANGE + 1) : 1 + (i & 11);
       for (int i = 0; i < M*N; ++i)
           C[i] = random_in ? (float)(rand() % VALRANGE + 1) : 1 + (i & 5);
       cout << "Complete" << endl;

       // gold starts as a copy of the initial C; the host reference later
       // overwrites it in place with alpha*A*B + beta*C.
       if (check)
       {
#ifndef _TI_RTOS
           if ((gold = (float*) malloc(M*N*sizeof(float))) == NULL)
#else
           if ((gold = (float*) __malloc_ddr(M*N*sizeof(float))) == NULL)
#endif
           {
               printf("Unable to allocate memory to verify results\n");
               exit(-1);
           }
           memcpy(gold, C, M*N*sizeof(float));
       }

       // PrintMatrix silently skips matrices larger than 64 x 16.
       PrintMatrix(A,M,K,order);
       PrintMatrix(B,K,N,order);
       if (random_in)  PrintMatrix(C,M,N,order);

       // Unmap the buffers and drain the queue before launching the kernel.
       Q.enqueueUnmapMemObject(bufA, A);
       Q.enqueueUnmapMemObject(bufB, B);
       Q.enqueueUnmapMemObject(bufC, C);
       Q.finish();

       /*----------------------------------------------------------------------
       * Device: Do A*B = C
       * Global size NUMCOMPUNITS with local size 1: one single-work-item
       * work-group per compute unit.
       *---------------------------------------------------------------------*/
       Kernel kernel (program, "K_ocl_sgemm_dsp");
       KernelFunctor matmul = kernel.bind(Q, NDRange(NUMCOMPUNITS), NDRange(1));

       struct timespec t0,t1;
       clock_gettime(CLOCK_MONOTONIC, &t0);

       /*----------------------------------------------------------------------
       * Convert RowMajor computation to ColumnMajor computation
       * Fact: Mem_Layout(A_RowMajor) === Mem_Layout(Transpose(A)_ColMajor)
       * Therefore: C_RowMajor = A_RowMajor * B_RowMajor
       *            C[mxn] = A[mxk] * B[kxn]
       * can be computed as:
       * Transpose(C)_ColMajor = Transpose(B)_ColMajor * Transpose(A)_ColMajor
       * C'[nxm] = B'[nxk] * A'[kxm],
       * where ptrC' === ptrC, ptrA' === ptrA, ptrB' === ptrB
       * So, all we need to do is to: swap(m, n), swap(a, b)
       * ld_Transpose(a)_col = lda_row = k,
       * ld_Transpose(b)_col = ldb_row = n,
       * ld_Transpose(c)_col = ldc_row = n,
       *---------------------------------------------------------------------*/
       // .wait() blocks until the kernel finishes, so t1 - t0 covers the
       // enqueue plus the whole device execution.
       if (order == CblasRowMajor)
           matmul(N, M, K, alpha, bufB, N, bufA, K, beta, bufC, N,
                  NUMAPANELS, NUMBPANELS,
                  __local(L2_BUF_SIZE), *bufMsmc).wait();
       else
           matmul(M, N, K, alpha, bufA, M, bufB, K, beta, bufC, M,
                  NUMAPANELS, NUMBPANELS,
                  __local(L2_BUF_SIZE), *bufMsmc).wait();

       clock_gettime(CLOCK_MONOTONIC, &t1);
       dsp_elapsed = clock_diff (&t0, &t1);

       double gflops = total_GFLOP/dsp_elapsed;
       printf("%4d DSPs: %.3f Gflops (%.6f s) \n",
              NUMCOMPUNITS, gflops, dsp_elapsed);

       // NOTE(review): bufMsmc is a raw owning pointer; if anything above
       // throws, this delete is skipped.
       if (bufMsmc != NULL)  delete bufMsmc;

       /*----------------------------------------------------------------------
       * If checking results against a host matmul.
       * This can be time consuming for large sizes.
       *---------------------------------------------------------------------*/
       // Map the device result for reading on the host.
       C = (float*) Q.enqueueMapBuffer(bufC, CL_TRUE, CL_MAP_READ, 0,
                                       M*N*sizeof(float));

       if (check)
       {
           A = (float*) Q.enqueueMapBuffer(bufA, CL_TRUE, CL_MAP_READ, 0,
                                           M*K*sizeof(float));
           B = (float*) Q.enqueueMapBuffer(bufB, CL_TRUE, CL_MAP_READ, 0,
                                           K*N*sizeof(float));
           clock_gettime(CLOCK_MONOTONIC, &t0);

           MatmulHost_ATLAS(order, A, B, gold, M, N, K, alpha, beta);

           clock_gettime(CLOCK_MONOTONIC, &t1);
           // dsp_elapsed is reused here to hold the host (CPU) time.
           dsp_elapsed = clock_diff (&t0, &t1);

           double gflops = total_GFLOP/dsp_elapsed;
           printf("   1 CPU : %.3f Gflops (%.6f s) with ATLAS library\n",
                  gflops, dsp_elapsed);
           Q.enqueueUnmapMemObject(bufA, A);
           Q.enqueueUnmapMemObject(bufB, B);
	   PrintMatrix(gold,M,N,order);
	   errs = CheckForErrors(C, gold, M, N, K, order);
#ifndef _TI_RTOS
           free(gold);
#else
           __free_ddr(gold);
#endif
       }

       PrintMatrix(C,M,N,order);
       Q.enqueueUnmapMemObject(bufC, C); Q.finish();
   }
   catch (Error& err)
   {
       // Report the OpenCL error with its numeric code and decoded name.
       cerr << "ERROR: " << err.what() << "(" << err.err() << ", "
            << ocl_decode_error(err.err()) << ")" << endl;
       exit(-1);
   }

   RETURN(errs);
}

/******************************************************************************
* Supporting Functions
******************************************************************************/
/* Print the command-line help (including the current value of M) and
 * terminate the program with exit status 0.  Never returns. */
void PrintUsageAndExit()
{
    static const char* const option_help[] = {
        "-M arg : Number of rows for array C and A",
        "-K arg : Number of cols for array A, rows for array B",
        "-N arg : Number of cols for array C and B",
        "-d     : Do not check results against host computation",
        "-r     : Generate random inputs",
        "-or    : Use Row-Major storage (default is Col-Major)",
        "-h     : Show help message"
    };

    cout << "Matrix C[M,N] = A[M,K] * B [K,N]" << endl;
    cout << "Default value of M,N,K is " << M << endl;
    cout << "Usage: sgemm [options] " << endl;
    cout << "Options: " << endl;

    for (const char* line : option_help)
        cout << line << endl;

    exit(0);
}

/* Parse the command line into the global run configuration.
 *
 *   -o <x> : row-major when <x> starts with 'r', otherwise column-major
 *   -M/-K/-N <n> : matrix dimensions (absolute value of the parsed integer)
 *   -d : disable result checking      -r : random inputs
 *   -x : print the derived blocking parameters
 *   -h or any unrecognized option : print usage and exit
 *
 * With no arguments at all the globals are left untouched. */
void HandleOptions(int argc, char* argv[])
{
    if (argc == 1) return;

    const char* const option_spec = "o:M:K:N:hdrx";

    for (int opt = getopt(argc, argv, option_spec);
         opt != -1;
         opt = getopt(argc, argv, option_spec))
    {
        if (opt == 'o')
        {
            if (*optarg == 'r') order = CblasRowMajor;
            else                order = CblasColMajor;
        }
        else if (opt == 'M') M = abs(atoi(optarg));
        else if (opt == 'K') K = abs(atoi(optarg));
        else if (opt == 'N') N = abs(atoi(optarg));
        else if (opt == 'd') check      = false;
        else if (opt == 'r') random_in  = true;
        else if (opt == 'x') calc_check = true;
        else                 PrintUsageAndExit();  // 'h' and anything unknown
    }
}

/* Write a rows x cols matrix to cout, one matrix row per output line,
 * followed by a blank line.  Elements are read from `mat` according to
 * `mem_order` (row-major: i*cols + j, otherwise column-major: j*rows + i).
 * Matrices with more than 64 rows or more than 16 columns are not printed. */
void PrintMatrix(float *mat, int rows, int cols, enum CBLAS_ORDER mem_order)
{
    if (rows > 64 || cols > 16) return;

    const bool row_major = (mem_order == CblasRowMajor);

    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < cols; ++c)
        {
            const int pos = row_major ? (r*cols + c) : (c*rows + r);
            cout << setprecision(9) << setw(10) << mat[pos] << " ";
        }
        cout << endl;
    }
    cout << endl;
}

#define EPISILON 0.01  // we have all integer inputs

/*-----------------------------------------------------------------------------
* Compare the M x N matrix `Mat` against the reference `Golden`.
*
* An element counts as an error when the difference Golden - Mat lies
* outside [-EPISILON, +EPISILON] OR is not a number.  The range test is
* written in its negated form on purpose: every ordered comparison with NaN
* is false, so the previous "delta < -EPS || delta > EPS" form silently
* accepted NaN results as correct.
*
* The first (print_nerrors - 1) mismatches are printed individually; a
* PASS/FAIL summary is always printed.
*
* Mat, Golden : matrices of M*N floats laid out according to mem_order
* M, N        : rows / columns of both matrices
* K           : unused, kept for interface compatibility
* mem_order   : CblasRowMajor indexes as i*N + j, anything else as j*M + i
*
* Returns the number of mismatching elements (0 means PASS).
*----------------------------------------------------------------------------*/
int CheckForErrors(const float *Mat, const float *Golden, int M, int N, int K,
                    enum CBLAS_ORDER mem_order)
{
    int       num_errors = 0, i, j;
    const int print_nerrors = 13;
    int       index;

    (void) K;  // not needed for an element-wise comparison

    for (i=0; i<M; i++)
        for (j=0; j<N; j++)
        {
            if (mem_order == CblasRowMajor) index = i*N + j;
            else                            index = j*M + i;

            float delta = Golden[index] - Mat[index];

            // Negated in-range test: also true when delta is NaN.
            if (!(delta >= -EPISILON && delta <= EPISILON))
                if ((num_errors += 1) < print_nerrors)
                    printf("Error [%d,%d]: %f <==> %f\n", i, j,
                           Golden[index], Mat[index]);
        }

    if (num_errors > 0)
         cout << "FAIL with " << num_errors << " errors!" << endl;
    else cout << "PASS!" << endl;
    return num_errors;
}

/* Host reference implementation: C = alpha * A * B + beta * C, computed by
 * cblas_sgemm without transposing A or B.  Leading dimensions follow the
 * storage order: row-major uses (K, N, N), otherwise (M, K, M) for
 * (A, B, C).  Compiled to an empty function when _TI_RTOS is defined. */
void MatmulHost_ATLAS(enum CBLAS_ORDER mem_order,
                const float*A, const float *B, float *C, int M, int N, int K,
                float alpha, float beta)
{
#ifndef _TI_RTOS
    const bool row_major = (mem_order == CblasRowMajor);

    const int lda = row_major ? K : M;
    const int ldb = row_major ? N : K;
    const int ldc = row_major ? N : M;

    cblas_sgemm(mem_order, CblasNoTrans, CblasNoTrans,
                M, N, K, alpha,
                A, lda,
                B, ldb,
                beta,
                C, ldc);
#endif
}

/* Largest power of two that is <= value; returns 0 for value == 0.
 *
 * Done purely in cl_ulong arithmetic.  The previous "1 << ilogb(value)"
 * shifted a plain int, which overflows for value >= 2^31 (for example a
 * device reporting 2 GB or more of global memory), and relied on converting
 * the integer to floating point to obtain the exponent. */
static cl_ulong roundDownPower2(cl_ulong value)
{
    if (value == 0) return 0;

    cl_ulong power = 1;
    while ((value >>= 1) != 0)   // one doubling per bit below the top set bit
        power <<= 1;
    return power;
}

/*-----------------------------------------------------------------------------
* Check platform name, set sgemm blocking/tiling parameters accordingly
*----------------------------------------------------------------------------*/
/* Query `device` and fill in the global sgemm configuration:
 *   NUMCOMPUNITS  - number of compute units reported by the device
 *   L2_BUF_SIZE   - local memory size rounded down to a power of two
 *   MSMC_BUF_SIZE - NUMCOMPUNITS * A-panel size * NUMAPANELS when that fits
 *                   in the (rounded) MSMC size, otherwise 0
 *   NUMAPANELS / NUMBPANELS - derived from L2_BUF_SIZE and the panel sizes
 *   M, N, K       - only when all three are still 0: a square size derived
 *                   from global memory (room for 3 matrices), capped at 2048
 * When calc_check is set the resulting values are printed. */
void SetSgemmParams(Device& device)
{
   const int a_panel_bytes = 8  << 10;
   const int b_panel_bytes = 16 << 10;

   cl_ulong global_bytes = 0;
   cl_ulong l2_bytes     = 0;
   cl_ulong msmc_bytes   = 0;

   device.getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &NUMCOMPUNITS);
   device.getInfo(CL_DEVICE_GLOBAL_MEM_SIZE,   &global_bytes);
   device.getInfo(CL_DEVICE_LOCAL_MEM_SIZE,    &l2_bytes);

#ifdef CL_DEVICE_MSMC_MEM_SIZE_TI
   device.getInfo(CL_DEVICE_MSMC_MEM_SIZE_TI,  &msmc_bytes);
#endif

   global_bytes  = roundDownPower2(global_bytes);
   L2_BUF_SIZE   = roundDownPower2(l2_bytes);
   MSMC_BUF_SIZE = roundDownPower2(msmc_bytes);

   // No dimension given on the command line: choose a square problem
   // sized so that three such matrices fit in global memory.
   const bool no_dims_given = (M == 0 && N == 0 && K == 0);
   if (no_dims_given)
   {
       int side = roundDownPower2(sqrt(global_bytes / 3 / sizeof(float)));
       if (side >= 2048) side = 2048;
       M = N = K = side;
   }

   NUMAPANELS = L2_BUF_SIZE / 2 / a_panel_bytes;
   NUMBPANELS = L2_BUF_SIZE / 4 / b_panel_bytes;

   // Use MSMC only when every compute unit's share of A panels fits in it.
   const int msmc_needed = NUMCOMPUNITS * a_panel_bytes * NUMAPANELS;
   MSMC_BUF_SIZE = (msmc_needed > MSMC_BUF_SIZE) ? 0 : msmc_needed;

   if (!calc_check) return;

   cout << "M,N,K         = " << M << ", " << N << ", " << K << endl;
   cout << "MSMC_BUF_SIZE = " << MSMC_BUF_SIZE << endl;
   cout << "L2_BUF_SIZE   = " << L2_BUF_SIZE << endl;
   cout << "NUMAPANELS    = " << NUMAPANELS << endl;
   cout << "NUMBPANELS    = " << NUMBPANELS << endl;
}

OpenCL设备端源码

#include "data.h"
#include "cblas.h"

// Matrix-multiplication API called by the kernel wrapper in this file
// (declaration only; it is called with L1 scratch, local (L2) and MSMC
// buffer pointers plus the calling core's number as `tid`).
void sgemm(
           int m, int n, int k,
           float alpha,
	   global float * a, int lda,
           global float * b, int ldb,
           float beta,
           global float * c, int ldc,
           int NUMAPANELS, int NUMBPANELS,
           float* pL1, local float* pL2, global float* pMsmc, int tid);

/* Kernel entry point: each work-item (one per work-group) computes a
 * contiguous slice of the m dimension of C = alpha*A*B + beta*C by calling
 * sgemm() on its own row range.
 *
 * Row split: with `cores` work-items, every core gets m / cores rows and
 * the first (m % cores) cores get one extra row.  When m < cores only the
 * first m cores do any work (one row each). */
kernel __attribute__((reqd_work_group_size(1,1,1))) void
K_ocl_sgemm_dsp(
                int m, int n, int k,
                float alpha,
                global float *a, int lda,
                global float *b, int ldb,
                float beta,
                global float *c, int ldc,
                int NUMAPANELS, int NUMBPANELS,
                local  float *L2_buf, global float *Msmc_buf)
{
    const int cores   = get_global_size(0);
    const int core_id = get_global_id(0);

    /* not enough rows for every core: the surplus cores have nothing to do */
    if (m < cores && core_id >= m)  return;

    int rows_here = (m < cores) ? 1 : m / cores;
    int row_start = rows_here * core_id;

    /* distribute the remainder rows over the lowest-numbered cores */
    if (m > cores)
    {
        const int leftover = m % cores;
        if (core_id < leftover)
        {
            rows_here += 1;
            row_start += core_id;
        }
        else
        {
            row_start += leftover;
        }
    }

    /* base address of the L1D SRAM that can be used as scratch memory */
    float* L1_buf = (float*) __scratch_l1d_start();

    /* A buffer address at or above 0x80000000 is treated as "no MSMC
     * buffer" (presumably the dummy non-MSMC buffer from the host --
     * confirm against the device memory map). */
    if (Msmc_buf >= (global float*) 0x80000000)  Msmc_buf = (global float*) 0;

    /* configure 16k of L1D as cache; the remainder serves as scratch memory */
    __cache_l1d_16k();

    sgemm(rows_here, n, k,
          alpha,
          a + row_start, lda,
          b, ldb,
          beta,
          c + row_start, ldc,
          NUMAPANELS, NUMBPANELS,
          L1_buf, L2_buf, Msmc_buf, __local_core_num());

    /* restore L1D to all-cache */
    __cache_l1d_all();
}

## 其他源文件

```cpp
//sgemm.cpp
#define USE_EDMA 1

#include "data.h"
#include "sgemm_kernel.h"
#include <string.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include "dsp_c.h"
#include "dsp_edmamgr.h"

/*-----------------------------------------------------------------------------
* On KeyStone devices, MSMC is not cached in L2 and this macro could be 
* #define DSP_wbInv_L2  __cache_l1d_flush.  However, the performance delta 
* is negligible.  On Sitara (AM57) devices, the MSMC or (OCMC) is cached in 
* L2 as well as L1D.
*
* Both macros are invoked right after a DMA transfer is started, on the
* destination buffer named in the comment at each call site.
*----------------------------------------------------------------------------*/
#define DSP_wbInv_L2  __cache_l2_flush // flush the L2 cache
#define DSP_wbInv_L1D __cache_l1d_flush // flush the L1D cache

// sgemm interface
// Assuming COLUMN MAJOR, NO TRANPOSE ON A AND B
// Computes C[MxN] = alpha * A[MxK] * B[KxN] + beta * C[MxN]
// Requirements: 16KB  of L1 SRAM, passed in as pL1
// Requirements: 128KB of L2 SRAM, passed in as pL2
void sgemm(
           const int m, const int n, const int k,
           const float  alpha,
           float* restrict a, const int lda,
           float* restrict b, const int ldb,
           const float  beta,
           float* restrict c, const int ldc,
           int NUMAPANELS, int NUMBPANELS,
           float* restrict pL1, float* restrict pL2,
           float* restrict pMsmc, int tid
          )
{
    if (m == 0 || n == 0 || ((alpha == 0.0f || k == 0) && beta == 1.0f)) return;

    float __attribute__((aligned(8)))
          ptrCTemp[CORE_PROCESS_ROWS*CORE_PROCESS_COLS];
    int mXferIndex, nXferIndex;
    int kIndex, kCnt, kCntNext;
    int mIndex, mCnt, mCntNext;
    int nIndex, nCnt, nCntNext/*, innerNCnt*/;
#if !(USE_EDMA)
    int kCntPrev, nCntPrev, mCntPrev;
#endif
    int innerIndex_m, innerIndex_n;
    int flagLastK, flagLastM, flagLastN, flagLastMXfers, flagLastNXfers;
    float * restrict ptrA, * restrict ptrB, * restrict ptrC;
    float * restrict ptrASeg1, * restrict ptrASeg2;
    float * restrict ptrBSeg1, * restrict ptrBSeg2;
    short  indexACurrent, indexANext, indexBCurrent, indexBNext;
    float * restrict ptrCInternal;
    int ldcInternal, i, j, nValid, mValid;

    // partition in m dimension
    int MPARTITION = (NUMAPANELS*CORE_PROCESS_ROWS);
    // partition in n dimension
    int NPARTITION = (NUMBPANELS*CORE_PROCESS_COLS);
    
    if (pMsmc)
    {
        // Keep unpacked A panels in MSMC SRAM
        ptrASeg1 = pMsmc+(tid)*MPARTITION*KPARTITION;
    }
    // Move packed A panel in L2 from DDR
    ptrASeg2 = pL2; 
    // Keep B panel ping pong buffers in L2 
    ptrBSeg1 = ptrASeg2+(MPARTITION*KPARTITION);
    ptrBSeg2 = ptrBSeg1+(NPARTITION*KPARTITION); 

    /* Beta scaling of C */
    if(beta != (float) 1.0) // only if scaling of C is needed
    {
        if(beta == (float) 0.0)
        {
            // zero out c column by column
            for(nCnt = 0; nCnt < n; nCnt++)
                memset(c + nCnt*ldc, 0, m * sizeof(float));
        } // if(beta==0.0f)
        else
        {
            // column by column multiplication
            for(nCnt = 0; nCnt < n; nCnt++)
                for(mCnt = 0; mCnt < m; mCnt++)
                    c[nCnt*ldc + mCnt] *= beta;
        } // else
    } // if(beta != 1.0f)

    mXferIndex = 0;
    nXferIndex = 0;

    mCnt = (m < MPARTITION) ? m : MPARTITION;
    kCnt = (k < KPARTITION) ? k : KPARTITION;
    nCnt = (n < NPARTITION) ? n : NPARTITION;

#if USE_EDMA
    /* Initialize EDMA Manager */
    EdmaMgr_Handle chan0, chan1;
    if (pMsmc != NULL)  chan0 = __ocl_EdmaMgr_alloc_intrakernel(1);
    chan1 = __ocl_EdmaMgr_alloc_intrakernel(1);
    if ((pMsmc != NULL && !chan0) || !chan1) 
    {  
        printf("Failed to alloc edma handle.\n");
    }
#endif

    if (pMsmc)
    {
        // initiate first transfer of A to MSMC
#if USE_EDMA
        EdmaMgr_copy2D2DSep(chan0,
                            a, /* src */
                            ptrASeg1, /* dst */
                            mCnt*sizeof(float), /* num_bytes */
                            kCnt, /* num_lines */
                            lda*sizeof(float), /* src_pitch */
                            MPARTITION*sizeof(float) /* dst_pitch */
                            ); //一次DMA操作
        DSP_wbInv_L2();  // ptrASeg1
#else
        for (i = 0; i < kCnt; i++)
            memcpy(ptrASeg1 + i * MPARTITION, a + i * lda, mCnt*sizeof(float));
#endif
    }

    // initiate first transfer of B to L2
#if USE_EDMA
    EdmaMgr_copy2D2DSep(chan1,
                        b, /* src */
                        ptrBSeg1, /* dst */
                        kCnt*sizeof(float), /* num_bytes */
                        nCnt, /* num_lines */
                        ldb*sizeof(float), /* src_pitch */
                        KPARTITION*sizeof(float) /* dst_pitch */
                        );
    DSP_wbInv_L1D();  // ptrBSeg1
#else
    for (i = 0; i < nCnt; i++)
        memcpy(ptrBSeg1 + i * KPARTITION, b + i * ldb, kCnt*sizeof(float));
#endif

    indexACurrent=1;
    indexANext=0;
    indexBCurrent=1;
    indexBNext=0;

    for(kIndex=0; kIndex<k; kIndex+=KPARTITION)  // partition in k dimension
    {
        nXferIndex = kIndex;
        kCnt = ((k-kIndex) < KPARTITION) ? (k-kIndex) : KPARTITION;
        kCntNext = ((k-kIndex-KPARTITION) < KPARTITION) ? (k-kIndex-KPARTITION) : KPARTITION;
        flagLastK = ((kIndex+KPARTITION) < k) ? 0 : 1;

        for(mIndex = 0; mIndex<m; mIndex+=MPARTITION)  // partition in m dimension
        {
            mCnt = ((m-mIndex) < MPARTITION) ? (m-mIndex) : MPARTITION;
            flagLastM = ((mIndex+MPARTITION)<m) ? 0 : 1;
            flagLastMXfers = ((mIndex+2*MPARTITION)<m) ? 0 : 1;
            mCntNext = ((m-mIndex-MPARTITION) < MPARTITION) ?
                       (m-mIndex-MPARTITION) : MPARTITION;
            mCntNext = (mCntNext <= 0) ? (m < MPARTITION ? m : MPARTITION)
                                       :  mCntNext;
            if(flagLastM) mCntNext = (m < MPARTITION) ? m : MPARTITION;

            // bring in A into MSMC SRAM (a new parallel transfer)
            indexACurrent = (indexACurrent+1) & 1;
            indexANext = (indexANext+1) & 1;

            // No need to memset invalid rows, because we select results
            // memset((void *) ptrASeg2, 0,
            //        MPARTITION * KPARTITION * sizeof(float));
            if (pMsmc)
            {
#if USE_EDMA
                EdmaMgr_wait(chan0); //等待与句柄相关的DMA操作完成
#endif
                dataMoveA(ptrASeg2, ptrASeg1, mCnt, kCnt, MPARTITION);
            }
            else
            {
                dataMoveA(ptrASeg2, a+mXferIndex, mCnt, kCnt, lda);
            }

            mXferIndex += mCnt;
            mXferIndex = (!flagLastM) ? mXferIndex: mXferIndex-m+kCnt*lda;

            if (pMsmc)
            {
                if ((!flagLastM) || (!flagLastK))
                {
                    if (mIndex == 0 || flagLastMXfers)
                    {
#if USE_EDMA
                        EdmaMgr_copy2D2DSep(chan0,
                                            a+mXferIndex, /* src */
                                            ptrASeg1, /* dst */
                                            mCntNext*sizeof(float), /* num_bytes */
                                            (flagLastM ? kCntNext : kCnt), /* num_lines */
                                            lda*sizeof(float), /* src_pitch */
                                            MPARTITION*sizeof(float) /* dst_pitch */
                                            );
                        DSP_wbInv_L2();  // ptrASeg1
#else
                        kCntPrev = (flagLastM ? kCntNext : kCnt);
                        for (i = 0; i < kCntPrev; i++)
                            memcpy(ptrASeg1 + i * MPARTITION,
                                   a+mXferIndex + i * lda,
                                   mCntNext*sizeof(float));
                        mCntPrev = mCntNext;
#endif
                    }
                    else if (flagLastM)
                    {
#if USE_EDMA
                        EdmaMgr_copy2D2DSep(chan0,
                                            a+mXferIndex, /* src */
                                            ptrASeg1, /* dst */
                                            mCntNext*sizeof(float), /* num_bytes */
                                            kCntNext, /* num_lines */
                                            lda*sizeof(float), /* src_pitch */
                                            MPARTITION*sizeof(float) /* dst_pitch */
                                            );
                        DSP_wbInv_L2();  // ptrASeg1
#else
                        kCntPrev = kCntNext;
                        for (i = 0; i < kCntPrev; i++)
                            memcpy(ptrASeg1 + i * MPARTITION,
                                   a+mXferIndex + i * lda,
                                   mCntNext*sizeof(float));
                        mCntPrev = mCntNext;
#endif
                    }
                    else
                    {
#if USE_EDMA
                        EdmaMgr_copyFast(chan0,
                                         a+mXferIndex, /* src */
                                         ptrASeg1 /* dst */
                                         ); //设置不同的起终点,使用该句柄上次指令的其他参数
                        DSP_wbInv_L2();  // ptrASeg1
#else
                        for (i = 0; i < kCntPrev; i++)
                            memcpy(ptrASeg1 + i * MPARTITION,
                                   a+mXferIndex + i * lda,
                                   mCntPrev*sizeof(float));
#endif
                    }
                }
            }


            for(nIndex = 0; nIndex<n; nIndex+=NPARTITION)  // partition in n dimension
            {
                nCnt = ((n-nIndex) < NPARTITION) ? (n-nIndex) : NPARTITION;
                nCntNext = ((n-nIndex-NPARTITION) < NPARTITION) ? (n-nIndex-NPARTITION) : NPARTITION;
                nCntNext = (nCntNext <= 0) ? (n < NPARTITION ? n : NPARTITION)
                                           : nCntNext;
                flagLastN = ((nIndex+NPARTITION)<n) ? 0 : 1;
                flagLastNXfers = ((nIndex+2*NPARTITION)<n) ? 0 : 1;
                if(flagLastN) nCntNext = (n < NPARTITION) ? n : NPARTITION;

                // bring in B into L1 SRAM (a new parallel transfer)
                indexBCurrent = (indexBCurrent+1) & 1;
                indexBNext = (indexBNext+1) & 1;
                
#if USE_EDMA
                EdmaMgr_wait(chan1);
#endif

                if((!flagLastM) || (!flagLastK) || (!flagLastN)) // don't carry out DMA for the last iteration
                {
                    nXferIndex += nCnt*ldb;
                    nXferIndex = (!flagLastN) ? nXferIndex: kIndex;
                    nXferIndex = ((!flagLastN) || (!flagLastM)) ? nXferIndex: (kIndex+kCnt);
                    ptrB = (indexBNext == 0) ? ptrBSeg1: ptrBSeg2;
                    if (nIndex == 0 || flagLastNXfers)
                    {
#if USE_EDMA
                        EdmaMgr_copy2D2DSep(chan1,
                                            b+nXferIndex, /* src */
                                            ptrB, /* dst */
                                            ((flagLastM && flagLastN) ? kCntNext : kCnt)*sizeof(float), /* num_bytes */
                                            nCntNext, /* num_lines */
                                            ldb*sizeof(float), /* src_pitch */
                                            KPARTITION*sizeof(float) /* dst_pitch */
                                            );
                        DSP_wbInv_L1D();  // ptrB
#else
                        for (i = 0; i < nCntNext; i++)
                            memcpy(ptrB + i * KPARTITION,
                                   b+nXferIndex + i * ldb,
                                   ((flagLastM && flagLastN)
                                    ? kCntNext : kCnt)*sizeof(float));
                        nCntPrev = nCntNext;
                        kCntPrev = (flagLastM && flagLastN) ? kCntNext : kCnt;
#endif
                    }
                    else if (flagLastM && flagLastN)
                    {
#if USE_EDMA
                        EdmaMgr_copy2D2DSep(chan1,
                                            b+nXferIndex, /* src */
                                            ptrB, /* dst */
                                            kCntNext*sizeof(float), /* num_bytes */
                                            nCntNext, /* num_lines */
                                            ldb*sizeof(float), /* src_pitch */
                                            KPARTITION*sizeof(float) /* dst_pitch */
                                            );
                        DSP_wbInv_L1D();  // ptrB
#else
                        for (i = 0; i < nCntNext; i++)
                            memcpy(ptrB + i * KPARTITION,
                                   b+nXferIndex + i * ldb,
                                   kCntNext*sizeof(float));
                        nCntPrev = nCntNext;
                        kCntPrev = kCntNext;
#endif
                    }
                    else
                    {
#if USE_EDMA
                        EdmaMgr_copyFast(chan1,
                                         b+nXferIndex, /* src */
                                         ptrB /* dst */
                                         );
                        DSP_wbInv_L1D();  // ptrB
#else
                        for (i = 0; i < nCntPrev; i++)
                            memcpy(ptrB + i * KPARTITION,
                                   b+nXferIndex + i * ldb,
                                   kCntPrev*sizeof(float));
#endif
                    }
                }

                // L2 memory assignment for B
                ptrB = (indexBCurrent == 0) ? ptrBSeg1: ptrBSeg2;
                // Corner case, zero out invalid cols in ptrB
                if (flagLastN && (nCnt % CORE_PROCESS_COLS != 0))
                {
                    for (i = 0; i < CORE_PROCESS_COLS -
                                    nCnt%CORE_PROCESS_COLS; i++)
                        memset((void *) (ptrB + (i+nCnt)*KPARTITION), 0,
                               kCnt * sizeof(float));
                }
                for(innerIndex_n = 0; innerIndex_n<nCnt; innerIndex_n+=CORE_PROCESS_COLS)
                {

                    dataMoveB(pL1, ptrB, kCnt);
                    ptrB += (CORE_PROCESS_COLS*KPARTITION);

                    // L2 memory assignment for B
                    ptrA = ptrASeg2;
                    // output memory assignment
                    ptrC= c + mIndex + (nIndex+innerIndex_n)*ldc;
                    for(innerIndex_m = 0; innerIndex_m<mCnt; innerIndex_m+=CORE_PROCESS_ROWS)
                    {
                        if (ldc % 2 != 0 ||
                            innerIndex_n + CORE_PROCESS_COLS > nCnt ||
                            innerIndex_m + CORE_PROCESS_ROWS > mCnt)
                        {
                            nValid = (nCnt - innerIndex_n > CORE_PROCESS_COLS)
                                     ? CORE_PROCESS_COLS : nCnt - innerIndex_n;
                            mValid = (mCnt - innerIndex_m > CORE_PROCESS_ROWS)
                                     ? CORE_PROCESS_ROWS : mCnt - innerIndex_m;
                            for (j = 0; j < nValid; j++)
                                for (i = 0; i < mValid; i++)
                                    ptrCTemp[j*CORE_PROCESS_ROWS + i] =
                                                           ptrC[j*ldc + i];
                            ptrCInternal = ptrCTemp;
                            ldcInternal  = CORE_PROCESS_ROWS;
                        }
                        else
                        {
                            ptrCInternal = ptrC;
                            ldcInternal  = ldc;;
                        }
                        // pre-fetch required A to L1 Cache
                        // 4xk * kx8 core matrix multiplications
                        __touch((const char *)ptrA,
                                CORE_PROCESS_ROWS * kCnt * sizeof(float));

                        sgemm_kernel(ptrA, pL1, ptrCInternal, alpha, kCnt,
                                     ldcInternal);

                        if (ldc % 2 != 0 ||
                            innerIndex_n + CORE_PROCESS_COLS > nCnt ||
                            innerIndex_m + CORE_PROCESS_ROWS > mCnt)
                        {
                            for (j = 0; j < nValid; j++)
                                for (i = 0; i < mValid; i++)
                                    ptrC[j*ldc + i] =
                                         ptrCTemp[j*CORE_PROCESS_ROWS + i];
                        }

                        // address of C to write to
                        ptrC += CORE_PROCESS_ROWS;
                        ptrA += (CORE_PROCESS_ROWS*KPARTITION);

                    } // inner loop m

                } // inner loop n
            } // n loop
        } // m loop
    } // k loop

#if USE_EDMA
    if (pMsmc)  EdmaMgr_free(chan0);
    EdmaMgr_free(chan1);
#endif
}

/*
 * sgemm_rowmajor -- row-major entry point; forwards to the sgemm routine.
 *
 * Scratch memory requirements (supplied by the caller):
 *   pL1 : 16KB  of L1 SRAM
 *   pL2 : 128KB of L2 SRAM
 *
 * How the row-major product is mapped onto the column-major routine:
 *   A row-major matrix has the same memory layout as its transpose stored
 *   column-major:  Mem_Layout(X_RowMajor) === Mem_Layout(Transpose(X)_ColMajor).
 *   Hence the row-major product
 *       C[m x n] = A[m x k] * B[k x n]
 *   is the column-major product
 *       C'[n x m] = B'[n x k] * A'[k x m]
 *   where C', B', A' occupy the very same buffers as C, B, A.
 *   Only the roles change: swap(m, n) and swap(a, b), with
 *       ldA'_col = ldA_row = k,
 *       ldB'_col = ldB_row = n,
 *       ldC'_col = ldC_row = n.
 *
 * All remaining arguments (alpha, beta, K, C/ldc, panel counts, scratch
 * pointers, tid) are passed through unchanged.
 */
void sgemm_rowmajor(
                 const int M, const int N, const int K,
                 const float alpha, float* restrict A, const int lda,
                 float* restrict B, const int ldb,
                 const float beta, float* restrict C, const int ldc,
                 int NUMAPANELS, int NUMBPANELS,
                 float* restrict pL1, float* restrict pL2,
                 float* restrict pMsmc, int tid)
{
    /* Dimensions as seen by the column-major routine: m and n trade places. */
    const int  colMajorRows = N;
    const int  colMajorCols = M;

    /* Operands trade places too: B becomes the left factor, A the right. */
    float     *leftOperand  = B;
    const int  leftLd       = ldb;
    float     *rightOperand = A;
    const int  rightLd      = lda;

    sgemm(colMajorRows, colMajorCols, K,
          alpha,
          leftOperand,  leftLd,
          rightOperand, rightLd,
          beta,
          C, ldc,
          NUMAPANELS, NUMBPANELS,
          pL1, pL2, pMsmc, tid);
}



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值