深入理解TVM中的TensorIR抽象
TensorIR是Apache TVM中用于表示张量计算的核心抽象,作为机器学习编译框架的重要组成部分,它能够精确描述循环结构以及硬件加速相关的各种特性,包括线程并行、专用硬件指令使用和内存访问模式等。
从NumPy到TensorIR
让我们从一个具体的例子开始:两个128×128矩阵A和B的计算过程,包含矩阵乘法和ReLU激活函数:
Y = A @ B
C = relu(Y) = max(Y, 0)
NumPy实现
首先看NumPy风格的实现,这有助于理解基本计算逻辑:
def numpy_mm_relu(A, B, C):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j in range(128):
for k in range(128):
if k == 0:
Y[i, j] = 0
Y[i, j] += A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
TensorIR实现
对应的TensorIR实现如下(使用TVMScript语法):
@tvm.script.ir_module
class MyModule:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
TensorIR核心概念解析
1. 函数参数与缓冲区
TensorIR使用T.Buffer
类型明确指定输入输出张量的形状和数据类型:
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
中间结果使用T.alloc_buffer
分配,类似于NumPy的empty
:
Y = T.alloc_buffer((128, 128), dtype="float32")
2. 循环迭代
TensorIR提供了T.grid
语法糖简化嵌套循环的编写:
for i, j, k in T.grid(128, 128, 128):
等价于:
for i in range(128):
for j in range(128):
for k in range(128):
3. 计算块(Block)
TensorIR的核心创新点是引入了T.block
概念:
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
每个块包含:
- 块轴(block axes)定义:
vi
,vj
,vk
- 计算逻辑
- 初始化部分(可选)
4. 块轴属性
块轴通过T.axis
定义,有三种关键属性:
vi = T.axis.spatial(128, i) # 空间轴
vj = T.axis.spatial(128, j) # 空间轴
vk = T.axis.reduce(128, k) # 归约轴
- 空间轴(spatial):对应输出张量的空间维度
- 归约轴(reduce):表示需要归约计算的维度(如矩阵乘法的k维度)
5. 块的自包含性
块定义包含了完整的计算语义,独立于外部循环结构。这种设计:
- 使优化变换更加局部化
- 便于验证循环与块的匹配关系
- 支持更灵活的调度变换
6. 语法糖:轴重映射
对于简单映射关系,可以使用T.axis.remap
简化代码:
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
其中"SSR"表示三个轴的类型依次为:空间、空间、归约。
TensorIR的设计优势
- 显式硬件特性表达:直接描述并行、向量化等硬件特性
- 优化友好:块结构便于自动和手动优化
- 验证保障:类型和形状信息帮助捕获错误
- 多后端支持:可针对不同硬件生成高效代码
总结
TensorIR作为TVM的核心中间表示,通过引入块(block)和轴属性等概念,在保持表达力的同时提供了丰富的优化空间。理解TensorIR是掌握TVM编译流程的关键,也是进行高效算子开发和优化的基础。
对于想要深入机器学习编译的开发者,建议从简单的TensorIR示例入手,逐步理解其设计哲学和实现细节,这将为后续的优化器开发和硬件适配打下坚实基础。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考