import torch
import triton
import triton.language as tl
@triton.jit
def fused_bmm_softmax_kernel(
    # input / output pointers
    a_ptr, b_ptr, c_ptr,
    # tensor dimensions
    B, M, N, K,
    # strides of A (batch, row, k)
    stride_ab, stride_am, stride_ak,
    # strides of B (batch, k, col)
    stride_bb, stride_bk, stride_bn,
    # strides of C (batch, row, col)
    stride_cb, stride_cm, stride_cn,
    # tile-size parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr
):
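    # Overview: each program instance owns one output row of softmax(A @ B)
    # and streams over it in three passes: (1) find the row maximum,
    # (2) accumulate the sum of exponentials shifted by that maximum,
    # (3) normalize and store. The A @ B row is recomputed in every pass,
    # trading extra FLOPs for never materializing the full logits row.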
    # program IDs map to the (batch, row) position in the output
    batch_id = tl.program_id(0)
    row_id = tl.program_id(1)
    # running row maximum, initialized to -inf
    row_max = -float('inf')
    # Pass 1: compute the maximum of the current row
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets < N
        # A's row is loaded as a (BLOCK_SIZE_K, 1) column so the row-times-matrix
        # product can be formed as a broadcasted multiply plus a K-reduction
        # (tl.dot requires tiles of at least 16x16, which a one-row tile violates)
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
                 n_offsets[None, :] * stride_bn
        # accumulate partial matmul results for this column block
        accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask, other=0.0)
            b = tl.load(b_ptrs, mask=k_mask & n_mask[None, :], other=0.0)
            # (K, 1) * (K, N) -> (K, N), reduced over K -> (N,)
            accumulator += tl.sum(a * b, axis=0)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # update the row maximum, ignoring out-of-range columns
        max_val = tl.max(tl.where(n_mask, accumulator, -float('inf')))
        row_max = tl.maximum(row_max, max_val)
    # Pass 2: accumulate the sum of exponentials (exp_sum)
    exp_sum = 0.0
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets < N
        # recompute the logits for this column block (same as Pass 1)
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
                 n_offsets[None, :] * stride_bn
        accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask, other=0.0)
            b = tl.load(b_ptrs, mask=k_mask & n_mask[None, :], other=0.0)
            accumulator += tl.sum(a * b, axis=0)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # stabilized exponentials; padding columns contribute zero
        exp_vals = tl.exp(accumulator - row_max)
        exp_sum += tl.sum(tl.where(n_mask, exp_vals, 0.0))
    # Pass 3: normalize and write the result
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets < N
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
                 tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
                 n_offsets[None, :] * stride_bn
        accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask, other=0.0)
            b = tl.load(b_ptrs, mask=k_mask & n_mask[None, :], other=0.0)
            accumulator += tl.sum(a * b, axis=0)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # final softmax values for this column block
        exp_vals = tl.exp(accumulator - row_max)
        softmax_output = exp_vals / exp_sum
        # write the output, masking out-of-range columns
        c_ptrs = c_ptr + batch_id * stride_cb + row_id * stride_cm + \
                 n_offsets * stride_cn
        tl.store(c_ptrs, softmax_output, mask=n_mask)
# host-side wrapper callable from PyTorch
def fused_bmm_softmax(a: torch.Tensor, b: torch.Tensor):
    B, M, K = a.shape
    B2, K2, N = b.shape
    assert B == B2 and K == K2, "incompatible shapes for bmm"
    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
    # tile sizes (tune for the target hardware)
    BLOCK_SIZE_M = 1   # each program instance handles one row
    BLOCK_SIZE_N = 64  # tile size along the column dimension
    BLOCK_SIZE_K = 32  # tile size along the K dimension
    grid = (B, M)  # one program instance per (batch, row)
    fused_bmm_softmax_kernel[grid](
        a, b, c,
        B, M, N, K,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return c
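A quick correctness check against the unfused PyTorch reference is useful before benchmarking; the sketch below is a minimal example, and the device string, shapes, and tolerances are illustrative assumptions:

# Sanity-check sketch: compare the fused kernel with torch.bmm + softmax.
# N=100 is deliberately not a multiple of BLOCK_SIZE_N to exercise masking.
def check_fused_bmm_softmax(B=2, M=64, N=100, K=96, device='cuda'):
    a = torch.randn(B, M, K, device=device, dtype=torch.float32)
    b = torch.randn(B, K, N, device=device, dtype=torch.float32)
    ref = torch.softmax(torch.bmm(a, b), dim=-1)
    out = fused_bmm_softmax(a, b)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)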
Substitute the fused kernel above into the code below:

import os
import sys
import random
import itertools
import csv
from triton import Config
import torch
import triton
import torch_mlu
import torch.nn as nn
import triton.language as tl
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
def get_config():
row_block = list(range(1, 33))
num_stages = [1, 2, 3, 4]
configs = list(itertools.product(row_block, num_stages))
random.shuffle(configs)
return configs
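# Note: get_config() yields 32 x 4 = 128 candidate (ROW_BLOCK, num_stages)
# pairs, shuffled so that the benchmark samples the space in random order.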
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows,
n_cols, BLOCK_SIZE: tl.constexpr, OUTER_ROW_BLOCK: tl.constexpr,
ROW_BLOCK: tl.constexpr, num_stages: tl.constexpr):
# The rows of the softmax are independent, so we parallelize across rows.
    pid = tl.program_id(0)
    # this program owns rows [pid * OUTER_ROW_BLOCK, ub)
    ub = tl.minimum(n_rows, (pid + 1) * OUTER_ROW_BLOCK)
    outer_row_idx = pid * OUTER_ROW_BLOCK
    for inner_row in range(outer_row_idx, ub, ROW_BLOCK):
row_offsets = tl.arange(0, ROW_BLOCK)
row_mask = (inner_row + row_offsets < ub)[:, None]
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + (inner_row + row_offsets) * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)[None, :]
        col_mask = col_offsets < n_cols
        input_ptrs = row_start_ptr[:, None] + col_offsets
        # Load rows into SRAM; masked lanes read -inf so they neither affect
        # the row max nor contribute to the exp sum
        mask = row_mask & col_mask
        row = tl.load(input_ptrs, mask=mask, other=-float('inf')).to(tl.float32)
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=1)[:, None]
# Note that exponentiation in Triton is fast but approximate
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1)[:, None]
softmax_output = numerator / denominator
softmax_output = softmax_output.to(tl.float16)
# Write back output to DRAM
output_row_start_ptr = output_ptr + (inner_row +
row_offsets) * output_row_stride
output_ptrs = output_row_start_ptr[:, None] + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
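# Launch scheme used below: a fixed grid of 48 programs (presumably matching
# the compute-core count of the target MLU); each program owns
# OUTER_ROW_BLOCK consecutive rows and walks them ROW_BLOCK rows at a time.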
def triton_softmax_perf(config, x):
    n_rows, n_cols = x.shape
    # BLOCK_SIZE must be a power of two for tl.arange; the kernel masks the
    # padding columns
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Allocate output
y = torch.empty_like(x)
    # Enqueue the kernel, splitting the rows evenly across 48 programs
    OUTER_ROW_BLOCK = triton.cdiv(n_rows, 48)
    grid = (48, 1, 1)
    softmax_kernel[grid](y,
x,
x.stride(0),
y.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
OUTER_ROW_BLOCK=OUTER_ROW_BLOCK,
                         ROW_BLOCK=config[0],
                         num_stages=config[1])
return y
def benchmark_softmax(configs, M, K, warmup, repeat):
    # create input data
a = torch.randn(M, K, dtype=torch.float16, device='mlu').contiguous()
trainset, count, maxsampled = [], 0, 20
for i in range(len(configs)):
        try:  # filter out invalid settings
            c_triton = triton_softmax_perf(configs[i], a)
        except Exception:
            continue
for _ in range(warmup):
c_triton = triton_softmax_perf(configs[i], a)
torch.mlu.synchronize()
start_event = torch.mlu.Event(enable_timing=True)
end_event = torch.mlu.Event(enable_timing=True)
start_event.record()
for _ in range(repeat):
c_triton = triton_softmax_perf(configs[i], a)
torch.mlu.synchronize()
end_event.record()
torch.mlu.synchronize()
trainset.append(list(configs[i]))
trainset[count].append(start_event.hardware_time(end_event) / repeat / 1000)
print('train data {}: {}'.format(count, trainset[count]))
count += 1
if count >= maxsampled: break
    os.makedirs('data', exist_ok=True)
    with open('data/{}_{}.csv'.format(M, K), 'w', newline='') as f:
writer = csv.writer(f)
writer.writerows(trainset)
def main():
micro_shapes = [
(80, 768),
(604, 768),
(1216, 768),
(1968, 768),
]
configs = get_config()
for M, K in micro_shapes:
print('Processing input {} {}'.format(M, K))
benchmark_softmax(configs, M, K, warmup=10, repeat=20)
if __name__ == "__main__":
main()
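Once the per-shape CSVs are written, picking the winning configuration is a small reduction over the recorded rows; a minimal sketch, assuming the (ROW_BLOCK, num_stages, avg_latency) row layout written by benchmark_softmax above:

# Post-processing sketch (assumed CSV layout: ROW_BLOCK, num_stages, avg_latency)
def best_config(path):
    with open(path, newline='') as f:
        rows = [(int(r[0]), int(r[1]), float(r[2])) for r in csv.reader(f)]
    # the smallest average latency wins
    return min(rows, key=lambda r: r[2])

# usage: best_config('data/80_768.csv') -> (ROW_BLOCK, num_stages, latency)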