完整的测试仓库请参考原文链接（此处链接已丢失，仅残留编辑器占位文字“添加链接描述”）。
causal softmax 的手写算子（BANG C 实现）
这里为了提升性能，专门针对 ndim=2 和 ndim=3 的情况做了特殊处理。
#include "bang.h"
#include "cnrt.h"
// Size, in bytes, of the NRAM scratch area reserved by this file (256 KiB).
constexpr int NRAM_MAX_SIZE = 256 * 1024;
// One untyped scratch buffer in the `__nram__` memory space. causal_softmaxKernel
// casts it to T* and partitions it into its working sub-buffers.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
template<typename T>
__mlu_global__ void causal_softmaxKernel(T *destination, int *strideDest, int *shape, int othersize, int dimsize, int dimS, int mask, int ndim) {
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
const int maxNum = SRC_MAX_SIZE / sizeof(T);
int wSize = 128 / sizeof(T);
__nram__ T srcMax[2];
if (dimsize > maxNum) {
T *src = (T *) nram_buffer; //[maxNum]
T *destSum = src + maxNum; //[maxNum]
T *destSumFinal = destSum + maxNum;//[wSize]
T *tmp = destSumFinal + wSize; //[maxNum]
T destOldMax;
T destNewMax;
int remain = dimsize % maxNum;
int repeat = (dimsize - remain) / maxNum;
int remainT = othersize % taskDim;
int stepEasy = (othersize - remainT) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remainT ? stepHard : stepEasy);
int indStart = (taskId < remainT ? taskId * stepHard : (taskId - remainT) * stepEasy + remainT * stepHard);
for (int i = indStart; i < indStart + step; i++) {
int indd = 0;
int indi = i;
int lastI = indi % shape[ndim - 2];
for (int j = ndim - 2; j >= 0; --j) {
indd += (indi % shape[j]) * strideDest[j];
indi /= shape[j];
}
if (mask + 1 + lastI < maxNum) {
__bang_write_value(src, maxNum, -INFINITY); //提前设置负无穷
__memcpy(src, destination + indd, (mask + 1 + lastI) * sizeof(T), GDRAM2NRAM);//从destination读取对应数据
__bang_argmax(srcMax, src, maxNum); //获取最大值
__bang_write_value(destSum, maxNum, srcMax[0]);
__memcpy(destSum, src, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);//destSum前面(mask + 1 + lastI)为src,后面部分为最大值
__bang_sub_scalar(destSum, destSum, srcMax[0], maxNum); //destSum前面(mask + 1 + lastI)为(src - M),后面部分为0
__bang_active_exp_less_0(destSum, destSum, maxNum); //destSum前面(mask + 1 + lastI)为exp(src - M),后面部分为1
__bang_write_zero(src, maxNum); //重新设置src全部为0
__memcpy(src, destSum, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);//src前面(mask + 1 + lastI)为exp(src - M),后面部分为0
if (maxNum >= wSize) {
int segNum = maxNum / wSize;//准备数值求和
for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
for (int j = 0; j < strip; j++) {
__bang_add(destSum + j * wSize, destSum + j * wSize, destSum + (j + strip) * wSize, wSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, wSize);//此时destSum[0]保存的就是当前maxNum长度数据的数值和
} else {
__memcpy(destSumFinal, destSum, maxNum * sizeof(T), NRAM2NRAM);
__bang_reduce_sum(destSumFinal, destSumFinal, wSize);//此时destSum[0]保存的就是当前maxNum长度数据的数值和
}
T globalSumInv = 1.0 / (destSumFinal[0] - (maxNum - (mask + 1 + lastI)));//下面开始指数变换,写回GDRAM
__bang_mul_scalar(src, src, globalSumInv, maxNum);
__memcpy(destination + indd, src, maxNum * sizeof(T), NRAM2GDRAM);
__bang_write_zero(src, maxNum);
for (int s = 1; s < repeat; s++) {
__memcpy(destination + indd + s * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM);
}
if (remain) {
__memcpy(destination + indd + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM);
}
} else {
int newRemain = (mask + 1 + lastI) % maxNum;
int nR = (mask + 1 + lastI - newRemain) / maxNum;
__bang_write_zero(destSum, maxNum);
__bang_write_zero(destSumFinal, wSize);
destOldMax = -INFINITY;
destNewMax = -INFINITY;
for (int s = 0; s < nR; s++) {
__memcpy(src, destination + indd + s * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if (destNewMax < srcMax[0]) {
destNewMax = srcMax[0];
}
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
if (s > 0) {
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, src, maxNum);
destOldMax = destNewMax;
}
if (newRemain) {
//__bang_write_value(src, maxNum, -INFINITY);
__memcpy(src, destination + indd + nR * maxNum, newRemain * sizeof(T), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if (destNewMax < srcMax[0]) {
destNewMax = srcMax[0];
}
__bang_write_value(tmp, maxNum, destNewMax);
__memcpy(tmp, src, newRemain * sizeof(T), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, maxNum);
__bang_active_exp_less_0(tmp, tmp, maxNum);
if (nR > 0) {
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, tmp, maxNum);
destOldMax = destNewMax;
}
if (maxNum >= wSize) {
int segNum = maxNum / wSize;//准备数值求和
for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
for (int j = 0; j < strip; j++) {
__bang_add(destSum + j * wSize, destSum + j * wSize, destSum + (j + strip) * wSize, wSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, wSize);//此时destSum[0]保存的就是当前maxNum长度数据的数值和
} else {
__memcpy(destSumFinal, destSum, maxNum * sizeof(T), NRAM2NRAM);
__bang_reduce_sum(destSumFinal, destSumFinal, wSize);//此时destSum[0]保存的就是当前maxNum长度数据的数值和
}
T globalSumInv;
if (newRemain) {
globalSumInv = 1.0 / (destSumFinal[0] - (maxNum - newRemain));//下面开始指数变换,写回GDRAM
} else {
globalSumInv = 1.0 / destSumFinal[0];//下面开始指数变换,写回GDRAM
}
for (int s = 0; s < nR; s++) {
__memcpy(src, destination + indd + s * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
__bang_mul_scalar(src, src, globalSumInv, maxNum);
__memcpy(destination + indd + s * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM);
}
__bang_write_zero(src, maxNum);
for (int s = nR; s < repeat; s++) {
__memcpy(destination + indd + s * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM);
}
if (remain) {
__memcpy(destination + indd + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM);
}
if (newRemain) {
__memcpy(src, destination + indd + nR * maxNum, newRemain * sizeof(T), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);