https://github.com/tpoisonooo/how-to-optimize-gemm/blob/master/aarch64/output_MMult_4x4_10.m 学习笔记
前言
看了半天,不好懂,记一下
零、MNK
mnk是举证乘法中的行列数
C=C+AB
C是mxn的矩阵,A是mxk的矩阵,B是kxn的矩阵
一、函数之前
#ifdef __ARM_NEON
#include <arm_neon.h>
#else
#error("arm neon not supported")
#endif
这段是进行一个检测,检测当前平台架构是否是arm架构,如果是arm架构的话,会自动生成__ARM_NEON
如果不是arm的话,报错
/* Block sizes */
#define mc 256
#define kc 128
设置分块大小
/* Create macros so that the matrices are stored in row-major order */
#define A(i, j) a[(i) * lda + (j)]
#define B(i, j) b[(i) * ldb + (j)]
#define C(i, j) c[(i) * ldc + (j)]
#define min(i, j) ((i) < (j) ? (i) : (j))
定义major row(行优先存储),因为lda()与i相乘,所以是行优先,如果与j相乘则为列优先。
二、MY_MMult、InnerKernel以及AddDot4x4
一共是做了三层函数(这不正好对应了cuda中的grid、block、thread吗?)
最外层是MY_MMult、其次是InnerKernel,最内层是AddDot4x4
最外层矩阵分块,中间层按照4x4的大小来进行矩阵相乘,最内层是矩阵乘法的实现(4x4)
void MY_MMult(int m, int n, int k, float *a, int lda, float *b, int ldb,
float *c, int ldc) {
int i, p, pb, ib;
for (p = 0; p < k; p += kc) {
pb = min(k - p, kc);
for (i = 0; i < m; i += mc) {
ib = min(m - i, mc);
InnerKernel(ib, n, pb, &A(i, p), lda, &B(p, 0), ldb, &C(i, 0), ldc);
}
}
}
注:pb=min(k-p,kc),kc是规则,按照kc的大小来对矩阵A的列和矩阵B的行进行分块,但是实际过程中,会出现到最后分不到kc大小的块,于是需要一个min操作保证列数始终是正确的
ib=min(m-i,mc)同理,不过是对矩阵A的行进行分块。
void InnerKernel(int m, int n, int k, float *a, int lda, float *b, int ldb,
float *c, int ldc) {
int i, j;
for (j = 0; j < n; j += 4) { /* Loop over the columns of C, unrolled by 4 */
for (i = 0; i < m; i += 4) { /* Loop over the rows of C */
/* Update C( i,j ), C( i,j+1 ), C( i,j+2 ), and C( i,j+3 ) in
one routine (four inner products) */
AddDot4x4(k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j), ldc);
}
}
}
简单的两个循环对已分块的矩阵进行4x4的划分
void AddDot4x4(int k, float *a, int lda, float *b, int ldb, float *c, int ldc) {
/* So, this routine computes a 4x4 block of matrix A
C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 ).
C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 ).
C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 ).
C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 ).
Notice that this routine is called with c = C( i, j ) in the
previous routine, so these are actually the elements
C( i , j ), C( i , j+1 ), C( i , j+2 ), C( i , j+3 )
C( i+1, j ), C( i+1, j+1 ), C( i+1, j+2 ), C( i+1, j+3 )
C( i+2, j ), C( i+2, j+1 ), C( i+2, j+2 ), C( i+2, j+3 )
C( i+3, j ), C( i+3, j+1 ), C( i+3, j+2 ), C( i+3, j+3 )
in the original matrix C
In this version, we use registers for elements in the current row
of B as well */
float
/* Point to the current elements in the four rows of A */
*a_0p_pntr,
*a_1p_pntr, *a_2p_pntr, *a_3p_pntr;
a_0p_pntr = &A(0, 0);
a_1p_pntr = &A(1, 0);
a_2p_pntr = &A(2, 0);
a_3p_pntr = &A(3, 0);
float32x4_t c_0p_sum = { 0 };
float32x4_t c_1p_sum = { 0 };
float32x4_t c_2p_sum = { 0 };
float32x4_t c_3p_sum = { 0 };
register float a_0p_reg, a_1p_reg, a_2p_reg, a_3p_reg;
for (int p = 0; p < k; ++p) {
float32x4_t b_reg = vld1q_f32(&B(p, 0));
a_0p_reg = *a_0p_pntr++;
a_1p_reg = *a_1p_pntr++;
a_2p_reg = *a_2p_pntr++;
a_3p_reg = *a_3p_pntr++;
c_0p_sum = vmlaq_n_f32(c_0p_sum, b_reg, a_0p_reg);
c_1p_sum = vmlaq_n_f32(c_1p_sum, b_reg, a_1p_reg);
c_2p_sum = vmlaq_n_f32(c_2p_sum, b_reg, a_2p_reg);
c_3p_sum = vmlaq_n_f32(c_3p_sum, b_reg, a_3p_reg);
}
float *c_pntr = 0;
c_pntr = &C(0, 0);
float32x4_t c_reg = vld1q_f32(c_pntr);
c_reg = vaddq_f32(c_reg, c_0p_sum);
vst1q_f32(c_pntr, c_reg);
c_pntr = &C(1, 0);
c_reg = vld1q_f32(c_pntr);
c_reg = vaddq_f32(c_reg, c_1p_sum);
vst1q_f32(c_pntr, c_reg);
c_pntr = &C(2, 0);
c_reg = vld1q_f32(c_pntr);
c_reg = vaddq_f32(c_reg, c_2p_sum);
vst1q_f32(c_pntr, c_reg);
c_pntr = &C(3, 0);
c_reg = vld1q_f32(c_pntr);
c_reg = vaddq_f32(c_reg, c_3p_sum);
vst1q_f32(c_pntr, c_reg);
}
这其中使用了几个neon指令
首先从float指针开始,定义了四个float指针类型,每个指针分别指向A矩阵的四行
定义了四个float32x4_t(表示一个包含四个单精度浮点数的向量,SIMD思想)类型的变量,正好也对应了4x4的矩阵运算
使用register 关键字定义了四个float型变量在cpu寄存器中
进入第一个循环,对A的列B的行进行循环,每次移动一个长度,循环计数是p
使用vld1q_f32指令将B的第p行写入到float32x4_t类型的neon寄存器中
每次循环指针后移地将A矩阵的每一行的单个元素放到register float类型的变量中
用四个float32x4_t的sum变量来接收利用vmlaq_n_f32指令的累乘结果,内层是一个saxpy修正
saxpy修正(单行代码举例):
c_0p_sum = vmlaq_n_f32(c_0p_sum, b_reg, a_0p_reg);
这其中的b_reg和c_0p_sum是一个float32x4_t类型,是一个包含4个单精度浮点数的向量,而后面的a_0p_reg是一个float型标量,即saxpy型修正(y=ax+y,a为标量,标量乘以向量),并将结果存储到c_0p_sum中。
后续的四个部分代码就是将C的四行进行更新,将刚刚算出来的sum值与原本的C对应的位置的值相加再更新:指针指向C的相关位置,取出此位置的值于c_reg,再求c_reg和sum的和,最后更新C。
总结
此止