#include <immintrin.h>
#include <stdio.h> // 用于调试时的printf
// Matrix dimensions and storage layouts used throughout this file:
//   A: M x K, column-major
//   B: K x N, row-major
//   C: M x N, row-major
/* Element-access macros: each expands to an lvalue for element (i, j). */
// A is indexed column-major: element (i, j) lives at aptr[j*lda + i].
// NOTE(review): the "_T" suffix presumably means A is supplied as a
// transposed/column-major buffer -- confirm against the caller.
#define A_T(i,j,lda,aptr) (aptr)[(j)*(lda) + (i)]
// B is indexed row-major: element (i, j) lives at bptr[i*ldb + j].
#define B_RM(i,j,ldb,bptr) (bptr)[(i)*(ldb) + (j)]
// C is indexed row-major: element (i, j) lives at cptr[i*ldc + j].
#define C_RM(i,j,ldc,cptr) (cptr)[(i)*(ldc) + (j)]
// Minimum of two values. Function-like macro: each argument may be
// evaluated twice, so do not pass expressions with side effects.
#define min( i, j ) ( (i)<(j) ? (i): (j) )
// Cache-blocking sizes; each can be overridden at compile time (e.g. -DMC=...).
#ifndef MC
#define MC 240 // Block size along M (rows of A, rows of C); a multiple of 8 (M_TILE)
#endif
#ifndef NC
#define NC 768 // Block size along N (columns of B, columns of C); a multiple of 24 (N_TILE)
#endif
#ifndef KC
#define KC 128 // Block size along K (columns of A, rows of B)
#endif
// Micro-kernel tile shape: M_TILE rows by N_TILE columns of C per call.
#define M_TILE 8
#define N_TILE 24
// Forward declaration of the 8x24 micro-kernel (defined later in this file).
// It accumulates the product of an 8 x k_iter slice of A and a k_iter x 24
// slice of B into an 8 x 24 tile of C.
//   k_iter   : number of steps along the K dimension
//   a_tile   : first element of the A slice; lda_full is the distance (in
//              doubles) between consecutive K-steps of A
//   b_tile   : first element of the B slice; ldb_full is the distance (in
//              doubles) between consecutive K-steps (rows) of B
//   c_tile   : first element of the C tile; ldc_full is the distance (in
//              doubles) between consecutive rows of C
static void AddDot8x24( int k_iter,
const double *a_tile, int lda_full, /* leading dimension of A (M of the full matrix) */
const double *b_tile, int ldb_full, /* leading dimension of B (N of the full matrix) */
double *c_tile, int ldc_full ); /* leading dimension of C (N of the full matrix) */
/*
 * AddDotEdgeTile: scalar fallback for a partial tile at the bottom/right
 * border of a block.
 *
 * Accumulates C(0:rows-1, 0:cols-1) += A(0:rows-1, 0:k_iter-1) *
 * B(0:k_iter-1, 0:cols-1) using the same layouts as the vector micro-kernel:
 * A advances by lda_full per K-step with rows contiguous, B advances by
 * ldb_full per K-step with columns contiguous, C rows are ldc_full apart.
 * Only elements inside the rows x cols region are touched.
 */
static void AddDotEdgeTile( int rows, int cols, int k_iter,
                            const double *a_tile, int lda_full,
                            const double *b_tile, int ldb_full,
                            double *c_tile, int ldc_full )
{
    const double *a_k_ptr = a_tile;
    const double *b_k_ptr = b_tile;

    for (int p = 0; p < k_iter; ++p) {
        double *c_row = c_tile;
        for (int i = 0; i < rows; ++i) {
            const double a_val = a_k_ptr[i];
            for (int j = 0; j < cols; ++j) {
                c_row[j] += a_val * b_k_ptr[j];
            }
            c_row += ldc_full;
        }
        a_k_ptr += lda_full;
        b_k_ptr += ldb_full;
    }
}

/*
 * Kernel_McNcKc: update one current_mc x current_nc sub-block of C with the
 * product of a current_mc x current_kc block of A and a current_kc x
 * current_nc block of B.
 *
 *   a_block / global_lda : start of the A block and A's leading dimension
 *   b_block / global_ldb : start of the B block and B's leading dimension
 *   c_block / global_ldc : start of the C sub-block and C's leading dimension
 *
 * The sub-block is walked in M_TILE x N_TILE tiles. Full tiles go to the
 * AddDot8x24 micro-kernel. When current_mc is not a multiple of M_TILE or
 * current_nc is not a multiple of N_TILE, the remaining partial tiles are
 * handled by AddDotEdgeTile so that nothing outside the sub-block is read
 * or written.
 */
static void Kernel_McNcKc( int current_mc, int current_nc, int current_kc,
                           const double *a_block, int global_lda,
                           const double *b_block, int global_ldb,
                           double *c_block, int global_ldc)
{
    for (int i_block_loop = 0; i_block_loop < current_mc; i_block_loop += M_TILE) {
        /* Rows actually available in this tile (less than M_TILE only at the bottom edge). */
        const int rows_left = current_mc - i_block_loop;
        const int tile_rows = (rows_left < M_TILE) ? rows_left : M_TILE;

        for (int j_block_loop = 0; j_block_loop < current_nc; j_block_loop += N_TILE) {
            /* Columns actually available in this tile (less than N_TILE only at the right edge). */
            const int cols_left = current_nc - j_block_loop;
            const int tile_cols = (cols_left < N_TILE) ? cols_left : N_TILE;

            const double *a_tile_ptr = &A_T(i_block_loop, 0, global_lda, a_block);
            const double *b_tile_ptr = &B_RM(0, j_block_loop, global_ldb, b_block);
            double *c_tile_ptr = &C_RM(i_block_loop, j_block_loop, global_ldc, c_block);

            if (tile_rows == M_TILE && tile_cols == N_TILE) {
                AddDot8x24(current_kc, /* extent of the K sub-block */
                           a_tile_ptr, global_lda,
                           b_tile_ptr, global_ldb,
                           c_tile_ptr, global_ldc);
            } else {
                AddDotEdgeTile(tile_rows, tile_cols, current_kc,
                               a_tile_ptr, global_lda,
                               b_tile_ptr, global_ldb,
                               c_tile_ptr, global_ldc);
            }
        }
    }
}
/*
 * MY_MMult: cache-blocked matrix multiply driver.
 *
 *   m_dim, n_dim, k_dim       : problem dimensions (C is m x n, A is m x k, B is k x n)
 *   global_a_ptr, lda_full    : A, column-major, with leading dimension lda_full
 *   global_b_rm_ptr, ldb_full : B, row-major, with leading dimension ldb_full
 *   global_c_rm_ptr, ldc_full : C, row-major, with leading dimension ldc_full
 *
 * The iteration space is cut into NC-wide column panels (outer loop), KC-deep
 * slices along K (middle loop) and MC-tall row blocks (inner loop); each
 * resulting block triple is handed to Kernel_McNcKc. C is not cleared here:
 * the kernels add into whatever C already holds.
 */
void MY_MMult( int m_dim, int n_dim, int k_dim,
               const double *global_a_ptr, int lda_full,
               const double *global_b_rm_ptr, int ldb_full,
               double *global_c_rm_ptr, int ldc_full )
{
    /* Outer loop: column panels of B and C. */
    for (int col_start = 0; col_start < n_dim; col_start += NC) {
        const int panel_cols = min(n_dim - col_start, NC);

        /* Middle loop: slices along the shared K dimension. */
        for (int depth_start = 0; depth_start < k_dim; depth_start += KC) {
            const int panel_depth = min(k_dim - depth_start, KC);

            /* A is column-major, so K-step depth_start begins depth_start columns in. */
            const double *a_depth_base = global_a_ptr + depth_start * lda_full;
            /* B is row-major, so K-step depth_start begins depth_start rows in. */
            const double *b_depth_base = global_b_rm_ptr + depth_start * ldb_full;

            /* Inner loop: row blocks of A and C. */
            for (int row_start = 0; row_start < m_dim; row_start += MC) {
                const int panel_rows = min(m_dim - row_start, MC);

                /* Block origins: A(row_start, depth_start), B(depth_start, col_start),
                   C(row_start, col_start). */
                const double *a_origin = a_depth_base + row_start;
                const double *b_origin = b_depth_base + col_start;
                double *c_origin = global_c_rm_ptr + (row_start * ldc_full + col_start);

                Kernel_McNcKc(panel_rows, panel_cols, panel_depth,
                              a_origin, lda_full,
                              b_origin, ldb_full,
                              c_origin, ldc_full);
            }
        }
    }
}
static void AddDot8x24( int k_iter,
const double *a_tile, int lda_full,
const double *b_tile, int ldb_full,
double *c_tile, int ldc_full )
{
__m512d c_reg_00, c_reg_01, c_reg_02;
__m512d c_reg_10, c_reg_11, c_reg_12;
__m512d c_reg_20, c_reg_21, c_reg_22;
__m512d c_reg_30, c_reg_31, c_reg_32;
__m512d c_reg_40, c_reg_41, c_reg_42;
__m512d c_reg_50, c_reg_51, c_reg_52;
__m512d c_reg_60, c_reg_61, c_reg_62;
__m512d c_reg_70, c_reg_71, c_reg_72;
c_reg_00 = _mm512_load_pd(c_tile + 0*ldc_full + 0*8);
c_reg_01 = _mm512_load_pd(c_tile + 0*ldc_full + 1*8);
c_reg_02 = _mm512_load_pd(c_tile + 0*ldc_full + 2*8);
c_reg_10 = _mm512_load_pd(c_tile + 1*ldc_full + 0*8);
c_reg_11 = _mm512_load_pd(c_tile + 1*ldc_full + 1*8);
c_reg_12 = _mm512_load_pd(c_tile + 1*ldc_full + 2*8);
c_reg_20 = _mm512_load_pd(c_tile + 2*ldc_full + 0*8);
c_reg_21 = _mm512_load_pd(c_tile + 2*ldc_full + 1*8);
c_reg_22 = _mm512_load_pd(c_tile + 2*ldc_full + 2*8);
c_reg_30 = _mm512_load_pd(c_tile + 3*ldc_full + 0*8);
c_reg_31 = _mm512_load_pd(c_tile + 3*ldc_full + 1*8);
c_reg_32 = _mm512_load_pd(c_tile + 3*ldc_full + 2*8);
c_reg_40 = _mm512_load_pd(c_tile + 4*ldc_full + 0*8);
c_reg_41 = _mm512_load_pd(c_tile + 4*ldc_full + 1*8);
c_reg_42 = _mm512_load_pd(c_tile + 4*ldc_full + 2*8);
c_reg_50 = _mm512_load_pd(c_tile + 5*ldc_full + 0*8);
c_reg_51 = _mm512_load_pd(c_tile + 5*ldc_full + 1*8);
c_reg_52 = _mm512_load_pd(c_tile + 5*ldc_full + 2*8);
c_reg_60 = _mm512_load_pd(c_tile + 6*ldc_full + 0*8);
c_reg_61 = _mm512_load_pd(c_tile + 6*ldc_full + 1*8);
c_reg_62 = _mm512_load_pd(c_tile + 6*ldc_full + 2*8);
c_reg_70 = _mm512_load_pd(c_tile + 7*ldc_full + 0*8);
c_reg_71 = _mm512_load_pd(c_tile + 7*ldc_full + 1*8);
c_reg_72 = _mm512_load_pd(c_tile + 7*ldc_full + 2*8);
__m512d a_bcast_0, a_bcast_1, a_bcast_2, a_bcast_3;
__m512d b_vec_0, b_vec_1, b_vec_2;
const double *a_k_ptr = a_tile;
const double *b_k_ptr = b_tile;
for (int p = 0; p < k_iter; ++p) {
b_vec_0 = _mm512_load_pd(b_k_ptr + 0*8);
b_vec_1 = _mm512_load_pd(b_k_ptr + 1*8);
b_vec_2 = _mm512_load_pd(b_k_ptr + 2*8);
a_bcast_0 = _mm512_set1_pd(*(a_k_ptr + 0));
a_bcast_1 = _mm512_set1_pd(*(a_k_ptr + 1));
a_bcast_2 = _mm512_set1_pd(*(a_k_ptr + 2));
a_bcast_3 = _mm512_set1_pd(*(a_k_ptr + 3));
c_reg_00 = _mm512_fmadd_pd(a_bcast_0, b_vec_0, c_reg_00);
c_reg_01 = _mm512_fmadd_pd(a_bcast_0, b_vec_1, c_reg_01);
c_reg_02 = _mm512_fmadd_pd(a_bcast_0, b_vec_2, c_reg_02);
c_reg_10 = _mm512_fmadd_pd(a_bcast_1, b_vec_0, c_reg_10);
c_reg_11 = _mm512_fmadd_pd(a_bcast_1, b_vec_1, c_reg_11);
c_reg_12 = _mm512_fmadd_pd(a_bcast_1, b_vec_2, c_reg_12);
c_reg_20 = _mm512_fmadd_pd(a_bcast_2, b_vec_0, c_reg_20);
c_reg_21 = _mm512_fmadd_pd(a_bcast_2, b_vec_1, c_reg_21);
c_reg_22 = _mm512_fmadd_pd(a_bcast_2, b_vec_2, c_reg_22);
c_reg_30 = _mm512_fmadd_pd(a_bcast_3, b_vec_0, c_reg_30);
c_reg_31 = _mm512_fmadd_pd(a_bcast_3, b_vec_1, c_reg_31);
c_reg_32 = _mm512_fmadd_pd(a_bcast_3, b_vec_2, c_reg_32);
a_bcast_0 = _mm512_set1_pd(*(a_k_ptr + 4));
a_bcast_1 = _mm512_set1_pd(*(a_k_ptr + 5));
a_bcast_2 = _mm512_set1_pd(*(a_k_ptr + 6));
a_bcast_3 = _mm512_set1_pd(*(a_k_ptr + 7));
c_reg_40 = _mm512_fmadd_pd(a_bcast_0, b_vec_0, c_reg_40);
c_reg_41 = _mm512_fmadd_pd(a_bcast_0, b_vec_1, c_reg_41);
c_reg_42 = _mm512_fmadd_pd(a_bcast_0, b_vec_2, c_reg_42);
c_reg_50 = _mm512_fmadd_pd(a_bcast_1, b_vec_0, c_reg_50);
c_reg_51 = _mm512_fmadd_pd(a_bcast_1, b_vec_1, c_reg_51);
c_reg_52 = _mm512_fmadd_pd(a_bcast_1, b_vec_2, c_reg_52);
c_reg_60 = _mm512_fmadd_pd(a_bcast_2, b_vec_0, c_reg_60);
c_reg_61 = _mm512_fmadd_pd(a_bcast_2, b_vec_1, c_reg_61);
c_reg_62 = _mm512_fmadd_pd(a_bcast_2, b_vec_2, c_reg_62);
c_reg_70 = _mm512_fmadd_pd(a_bcast_3, b_vec_0, c_reg_70);
c_reg_71 = _mm512_fmadd_pd(a_bcast_3, b_vec_1, c_reg_71);
c_reg_72 = _mm512_fmadd_pd(a_bcast_3, b_vec_2, c_reg_72);
a_k_ptr += lda_full;
b_k_ptr += ldb_full;
}
_mm512_store_pd(c_tile + 0*ldc_full + 0*8, c_reg_00);
_mm512_store_pd(c_tile + 0*ldc_full + 1*8, c_reg_01);
_mm512_store_pd(c_tile + 0*ldc_full + 2*8, c_reg_02);
_mm512_store_pd(c_tile + 1*ldc_full + 0*8, c_reg_10);
_mm512_store_pd(c_tile + 1*ldc_full + 1*8, c_reg_11);
_mm512_store_pd(c_tile + 1*ldc_full + 2*8, c_reg_12);
_mm512_store_pd(c_tile + 2*ldc_full + 0*8, c_reg_20);
_mm512_store_pd(c_tile + 2*ldc_full + 1*8, c_reg_21);
_mm512_store_pd(c_tile + 2*ldc_full + 2*8, c_reg_22);
_mm512_store_pd(c_tile + 3*ldc_full + 0*8, c_reg_30);
_mm512_store_pd(c_tile + 3*ldc_full + 1*8, c_reg_31);
_mm512_store_pd(c_tile + 3*ldc_full + 2*8, c_reg_32);
_mm512_store_pd(c_tile + 4*ldc_full + 0*8, c_reg_40);
_mm512_store_pd(c_tile + 4*ldc_full + 1*8, c_reg_41);
_mm512_store_pd(c_tile + 4*ldc_full + 2*8, c_reg_42);
_mm512_store_pd(c_tile + 5*ldc_full + 0*8, c_reg_50);
_mm512_store_pd(c_tile + 5*ldc_full + 1*8, c_reg_51);
_mm512_store_pd(c_tile + 5*ldc_full + 2*8, c_reg_52);
_mm512_store_pd(c_tile + 6*ldc_full + 0*8, c_reg_60);
_mm512_store_pd(c_tile + 6*ldc_full + 1*8, c_reg_61);
_mm512_store_pd(c_tile + 6*ldc_full + 2*8, c_reg_62);
_mm512_store_pd(c_tile + 7*ldc_full + 0*8, c_reg_70);
_mm512_store_pd(c_tile + 7*ldc_full + 1*8, c_reg_71);
_mm512_store_pd(c_tile + 7*ldc_full + 2*8, c_reg_72);
}
/* Question: please explain in detail what optimization work this code performs. */
/* (Web page label: "Latest posts") */