一、学习目标
1. 掌握TIKC++常用数据结构
2. 了解3/2/0级接口API的概念
二、数据结构
1. GlobalTensor
GlobalTensor用来存放Global Memory(外部存储)的全局数据
2. LocalTensor
LocalTensor用于存放核上Local Memory(内部存储)的数据
三、矢量计算指令接口
矢量计算指令接口,能够启动AI Core中的Vector单元执行计算
为了降低开发者的使用门槛,指令按照由易到难,分成了3级到0级接口。其中3级接口最为简单,0级接口最为复杂,(1级接口还未发布)
多层级API封装的作用:
- 降低复杂指令的使用难度
- 跨代兼容性保障
- 保留最大灵活度的可能
四、多级接口API
1. 3级接口
3级接口,运算符重载,支持+, -, *, /, |, &, ^, >, < , >=, <=,!=,==实现2级接口的简化表达
允许用户使用形如:dst = src0 ※ src1,针对整个Tensor进行计算
以下指令API拥有3级接口:
Add:dstLocal = src0Local + src1Local;
Sub:dstLocal = src0Local - src1Local;
Mul:dstLocal = src0Local * src1Local;
Div:dstLocal = src0Local / src1Local;
And:dstLocal = src0Local & src1Local;
Or:dstLocal = src0Local | src1Local;
Compare:
dstLocal = src0Local < src1Local;
dstLocal = src0Local > src1Local;
dstLocal = src0Local <= src1Local;
dstLocal = src0Local >= src1Local;
dstLocal = src0Local == src1Local;
dstLocal = src0Local != src1Local;
注意:三级接口会进行连续矢量运算,运算量为目的LocalTensor的总长度
2. 2级接口
2级连续计算接口,针对源操作数srcLocal的连续COUNT个数据进行计算,并连续写入目的操作数dstLocal,提供了一维Tensor的连续COUNT个数据的计算支持
允许用户使用形如:
void Operator(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const int32_t& calCount)
大多数指令API拥有2级接口,2级接口相对于3级接口,可以自定义运算量
- Exp(dstLocal, srcLocal, 512);
- Adds(dstLocal, srcLocal, scalarValue, 512);
- Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_CMPMASK_SPR, 256);
- ReduceMin(dstLocal, srcLocal, workLocal, 8320, true);
- Duplicate(dstLocal, inputVal, 256);
注意:二级接口会进行连续矢量运算,开发者指定的运算量不能超过参与运算Tensor本身的大小
3. 0级接口
0级功能灵活计算接口,是最底层的开发接口,可以完整发挥硬件优势的计算API,可以进行非连续的计算
该功能可以充分发挥CANN系列芯片的强大功能指令,支持对每个操作数的Block stride,Repeat stride,MASK的操作,允许用户使用诸多的通用参数来定制化所需要的操作
通用参数包括:
- Repeat times(迭代的次数)
- Block stride(单次迭代内不同block间地址步长)
- Repeat stride(相邻迭代间相同block的地址步长)
- Mask(用于控制参与运算的计算单元)
允许用户使用形如:
Mask逐比特模式
template <typename T> __aicore__ inline void Exp(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, uint64_t mask[2], const uint8_t repeatTimes, const UnaryRepeatParams& repeatParams);
Mask连续模式
template <typename T> __aicore__ inline void Exp(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, uint64_t mask, const uint8_t repeatTimes, const UnaryRepeatParams& repeatParams);
(1)重复迭代次数-Repeat times
矢量计算单元,一次最多可以计算256Bytes的数据,每次读取连续的8个block(每个block 32Bytes,共256Bytes)数据进行计算,为完成对输入数据的处理,必须通过多次迭代(repeat)才能完成所有数据的读取与计算。
(2)相邻迭代间相同block的地址步长-Repeat stride
当Repeat times大于1,需要多次迭代完成矢量计算时,可以根据不同的使用场景合理设置相邻迭代间相同block的地址步长Repeat stride的值
- 连续计算场景:假设定义一个Tensor供目的操作数和源操作数同时使用(即地址重叠),Repeat stride取值为8。此时,矢量计算单元第一次迭代读取连续8个block,第二轮迭代读取下一个连续的8个block,通过多次迭代即可完成所有输入数据的计算
- 非连续计算场景:Repeat stride取值大于8(如取10)时,则相邻迭代间矢量计算单元读取的数据在地址上不连续,出现2个block的间隔
- 反复计算场景:Repeat stride取值为0时,矢量计算单元会对首个连续的8个block进行反复读取和计算
- 部分重复计算:Repeat stride取值大于0且小于8时,相邻迭代间部分数据会被矢量计算单元重复读取和计算,此种情形一般场景不涉及
(3)同一迭代内不同block的地址步长-Block stride
如果需要控制单次迭代内,数据处理的步长,可以通过设置同一迭代内不同block的地址步长Block stride来实现。
- 连续计算,Block stride 设置为1,对同一迭代内的8个block数据连续进行处理
- 非连续计算,Block stride值大于1(如取2),同一迭代内不同block之间在读取数据时出现一个block的间隔
(4)Mask
Mask用于控制每次迭代内参与计算的元素。可通过连续模式和逐比特模式两种方式进行设置
- 连续模式:表示前面连续的多少个元素参与计算。数据类型为uint64_t。取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同(当前数据类型单次迭代时能处理的元素个数最大值为:256 / sizeof(数据类型))。当操作数的数据类型占比特位16位时(如half,uint16_t),mask∈[1, 128];当操作数为32位时(如float, int32_t),mask∈[1, 64]。
- 逐比特模式:可以按位控制哪些元素参与计算,比特位的值为1表示参与计算,0表示不参与。参数类型为长度为2的uint64_t类型数组
参数取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask[0]、mask[1]∈[0, 264-1];当dst/src为32位时,mask[1]为0,mask[0]∈[0, 264-1]
4. 以下指令形式表示的含义相同
// int16_t数据类型, dstLocal长度为512个int16_t
// 0级接口样例-mask连续模式
uint64_t mask = 128;
// repeatTimes = 4, 一次迭代计算128个数, 共计算512个数
// dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内数据连续读取和写入
// dstRepStride, src0RepStride, src1RepStride = 8, 相邻迭代间数据连续读取和写入
Add(dstLocal, src0Local, src1Local, mask, 4, { 1, 1, 1, 8, 8, 8 });
// 0级接口样例-mask逐bit模式
uint64_t mask[2] = { UINT64_MAX, UINT64_MAX };
// repeatTimes = 4, 一次迭代计算128个数, 共计算512个数
// dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内数据连续读取和写入
// dstRepStride, src0RepStride, src1RepStride = 8, 相邻迭代间数据连续读取和写入
Add(dstLocal, src0Local, src1Local, mask, 4, { 1, 1, 1, 8, 8, 8 });
// 2级接口样例
Add(dstLocal, src0Local, src1Local, 512);
// 3级接口样例
dstLocal = src0Local + src1Local;