[Triton Tutorial] triton_language.dot

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for efficiently writing custom DNN compute kernels that can run at maximal throughput on modern GPU hardware.

More Triton documentation in Chinese is available at → triton.hyper.ai/

triton.language.dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32)

Returns the matrix product of two blocks.

The two blocks must both be two- or three-dimensional and have compatible inner dimensions. For three-dimensional blocks, tl.dot performs a batched matrix product, where the first dimension of each block is treated as the batch dimension.
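As a minimal sketch of the 2D case (the kernel name and the contiguous row-major layout below are illustrative assumptions, not from this page):

```python
import triton
import triton.language as tl

@triton.jit
def dot_tile_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Load one (M, K) tile and one (K, N) tile; the inner dimension K
    # of the two blocks must match.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # (M, K) x (K, N) -> (M, N); the result accumulates in float32 by
    # default. Note that tl.dot generally requires each block dimension
    # to be at least 16.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

# Example launch on contiguous tensors (float16 inputs, float32 output):
# a = torch.randn(16, 16, device='cuda', dtype=torch.float16)
# b = torch.randn(16, 16, device='cuda', dtype=torch.float16)
# c = torch.empty(16, 16, device='cuda', dtype=torch.float32)
# dot_tile_kernel[(1,)](a, b, c, M=16, N=16, K=16)
```

With three-dimensional operands of shapes (B, M, K) and (B, K, N), the same tl.dot call instead produces B independent (M, N) products.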

**Parameters:**

  • input (2D or 3D tensor of scalar type {int8, float8_e5m2, float16, bfloat16, float32}) - the first tensor to multiply.
  • other (2D or 3D tensor of scalar type {int8, float8_e5m2, float16, bfloat16, float32}) - the second tensor to multiply.
  • acc (2D or 3D tensor of scalar type {float16, float32, int32}) - the accumulator tensor. If not None, the result is added into this tensor.
  • input_precision (string; the options on NVIDIA hardware are "tf32", "tf32x3", and "ieee", defaulting to "tf32"; on AMD the available option is "ieee") - determines how Tensor Cores are used for f32 x f32 computation. The option is ignored if the device has no Tensor Cores or the inputs are not of dtype f32. On devices with Tensor Cores, the default precision is tf32. See the sketch after this list for how acc and input_precision are used together.
  • allow_tf32 - deprecated. If true, input_precision is set to "tf32". Only one of input_precision and allow_tf32 may be specified (i.e., at least one of them must be None).
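A minimal sketch of a tiled K-loop that uses both parameters (the block sizes, pointer layout, and single-tile grid are illustrative assumptions; it assumes row-major float32 inputs with K divisible by BLOCK_K):

```python
import triton
import triton.language as tl

@triton.jit
def matmul_acc_kernel(a_ptr, b_ptr, c_ptr, K, N,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr):
    # Single-tile sketch: one program computes a (BLOCK_M, BLOCK_N) output.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_k[:, None] * N + offs_n[None, :]
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Passing acc feeds the running total into the multiply-accumulate;
        # input_precision="ieee" disables TF32 rounding for f32 x f32 inputs.
        acc = tl.dot(a, b, acc=acc, input_precision="ieee")
        a_ptrs += BLOCK_K      # advance along K in A (row-major)
        b_ptrs += BLOCK_K * N  # advance along K in B (row-major)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```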
A larger worked example fuses a batched matrix multiply (built on tl.dot) with a row-wise softmax, computing softmax(A @ B, dim=-1) in three passes over each output row:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_bmm_softmax_kernel(
        # input/output pointers
        a_ptr, b_ptr, c_ptr,
        # tensor dimensions
        B, M, N, K,
        # strides of A (batch, row, col)
        stride_ab, stride_am, stride_ak,
        # strides of B (batch, row, col)
        stride_bb, stride_bk, stride_bn,
        # strides of the output C (batch, row, col)
        stride_cb, stride_cm, stride_cn,
        # tiling parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr):
    # Each program id maps to one (batch, row) position of the output.
    # Caveat: tl.dot generally requires every block dimension to be >= 16,
    # so this one-row-per-program layout is a sketch that may need a larger
    # BLOCK_SIZE_M (plus row masking) on real hardware.
    batch_id = tl.program_id(0)
    row_id = tl.program_id(1)

    # Initialize the running row maximum to -inf.
    row_max = tl.full((), float('-inf'), dtype=tl.float32)

    # Pass 1: find the maximum of the current row.
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets[None, :] < N
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
            tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
            tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
            n_offsets[None, :] * stride_bn
        # Accumulate the partial matrix product along K.
        accumulator = tl.zeros((1, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K) < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
            b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Update the running row maximum, ignoring padded columns.
        masked_acc = tl.where(n_mask, accumulator, float('-inf'))
        row_max = tl.maximum(row_max, tl.max(masked_acc))

    # Pass 2: compute the sum of exponentials (exp_sum).
    exp_sum = tl.zeros((), dtype=tl.float32)
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        # Recompute the same tile as in Pass 1.
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets[None, :] < N
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
            tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
            tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
            n_offsets[None, :] * stride_bn
        accumulator = tl.zeros((1, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K) < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
            b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Numerically stable exponentials; padded columns contribute 0.
        exp_vals = tl.where(n_mask, tl.exp(accumulator - row_max), 0.0)
        exp_sum += tl.sum(exp_vals)

    # Pass 3: compute the normalized result and write it out.
    for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        n_offsets = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = n_offsets[None, :] < N
        a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \
            tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak
        b_ptrs = b_ptr + batch_id * stride_bb + \
            tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \
            n_offsets[None, :] * stride_bn
        accumulator = tl.zeros((1, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_mask = tl.arange(0, BLOCK_SIZE_K) < K - k * BLOCK_SIZE_K
            a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
            b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Final softmax values for this tile.
        exp_vals = tl.exp(accumulator - row_max)
        softmax_output = exp_vals / exp_sum
        # Write the output tile.
        c_ptrs = c_ptr + batch_id * stride_cb + row_id * stride_cm + \
            n_offsets[None, :] * stride_cn
        tl.store(c_ptrs, softmax_output.to(c_ptr.dtype.element_ty), mask=n_mask)
```

A wrapper makes the kernel callable from PyTorch:

```python
def fused_bmm_softmax(a: torch.Tensor, b: torch.Tensor):
    B, M, K = a.shape
    B, K, N = b.shape
    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
    # Tile sizes (tune these for the target hardware).
    BLOCK_SIZE_M = 1   # each program instance handles one row
    BLOCK_SIZE_N = 64  # tile size along the column dimension
    BLOCK_SIZE_K = 32  # tile size along the K dimension
    grid = (B, M)      # one program instance per batch per row
    fused_bmm_softmax_kernel[grid](
        a, b, c,
        B, M, N, K,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return c
```
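As a quick sanity check (a sketch, not from the original page), the wrapper's output can be compared against eager PyTorch, which expresses the same fused operation as bmm followed by a row-wise softmax:

```python
# Hypothetical check on a CUDA device; note the >= 16 block-dimension
# caveat above, which may require padding or retuning on some backends.
import torch

B, M, N, K = 4, 64, 128, 32
a = torch.randn(B, M, K, device='cuda', dtype=torch.float16)
b = torch.randn(B, K, N, device='cuda', dtype=torch.float16)

out = fused_bmm_softmax(a, b)
ref = torch.softmax(torch.bmm(a, b), dim=-1)
print(torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))
```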
The page then asks for a companion benchmarking harness for the fused kernel, modeled on the following script, which autotunes and times a standalone row-wise softmax kernel on MLU hardware and logs the results to CSV (only its methodology is meant to carry over):

```python
import os
import random
import itertools
import csv

import torch
import torch_mlu  # registers the 'mlu' device with PyTorch
import triton
import triton.language as tl

os.environ["TRITON_PRINT_AUTOTUNING"] = "1"


def get_config():
    # Candidate (ROW_BLOCK, num_stages) pairs, visited in random order.
    row_block = list(range(1, 33))
    num_stages = [1, 2, 3, 4]
    configs = list(itertools.product(row_block, num_stages))
    random.shuffle(configs)
    return configs


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   OUTER_ROW_BLOCK: tl.constexpr, ROW_BLOCK: tl.constexpr,
                   num_stages: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across rows.
    pid = tl.program_id(0)
    ub = tl.minimum(n_rows, (pid + 1) * OUTER_ROW_BLOCK)
    outer_row_idx = pid * OUTER_ROW_BLOCK
    for inner_row in range(outer_row_idx, ub, ROW_BLOCK):
        row_offsets = tl.arange(0, ROW_BLOCK)
        row_mask = (inner_row + row_offsets < ub)[:, None]
        # The stride is how far the pointer must advance to move down 1 row.
        row_start_ptr = input_ptr + (inner_row + row_offsets) * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)[None, :]
        input_ptrs = row_start_ptr[:, None] + col_offsets
        # Load rows into SRAM.
        row = tl.load(input_ptrs, mask=row_mask).to(tl.float32)
        # Subtract the maximum for numerical stability.
        row_minus_max = row - tl.max(row, axis=1)[:, None]
        # Note that exponentiation in Triton is fast but approximate.
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        softmax_output = (numerator / denominator).to(tl.float16)
        # Write the output back to DRAM.
        output_row_start_ptr = output_ptr + (inner_row + row_offsets) * output_row_stride
        output_ptrs = output_row_start_ptr[:, None] + col_offsets
        tl.store(output_ptrs, softmax_output, mask=row_mask)


def triton_softmax_perf(config, x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = n_cols
    # Allocate the output.
    y = torch.empty_like(x)
    # Enqueue the kernel, splitting all rows evenly across 48 programs.
    OUTER_ROW_BLOCK = triton.cdiv(n_rows, 48)
    grid = (48, 1, 1)
    softmax_kernel[grid](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                         BLOCK_SIZE=BLOCK_SIZE, OUTER_ROW_BLOCK=OUTER_ROW_BLOCK,
                         ROW_BLOCK=config[0], num_stages=config[1])
    return y


def benchmark_softmax(configs, M, K, warmup, repeat):
    # Create the input data.
    a = torch.randn(M, K, dtype=torch.float16, device='mlu').contiguous()
    trainset, count, maxsampled = [], 0, 20
    for i in range(len(configs)):
        try:
            # Filter out invalid settings.
            c_triton = triton_softmax_perf(configs[i], a)
        except Exception:
            continue
        for _ in range(warmup):
            c_triton = triton_softmax_perf(configs[i], a)
        torch.mlu.synchronize()
        start_event = torch.mlu.Event(enable_timing=True)
        end_event = torch.mlu.Event(enable_timing=True)
        start_event.record()
        for _ in range(repeat):
            c_triton = triton_softmax_perf(configs[i], a)
        torch.mlu.synchronize()
        end_event.record()
        torch.mlu.synchronize()
        trainset.append(list(configs[i]))
        trainset[count].append(start_event.hardware_time(end_event) / repeat / 1000)
        print('train data {}: {}'.format(count, trainset[count]))
        count += 1
        if count >= maxsampled:
            break
    os.makedirs('data', exist_ok=True)  # ensure the output directory exists
    with open('data/{}_{}.csv'.format(M, K), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerows(trainset)


def main():
    micro_shapes = [
        (80, 768),
        (604, 768),
        (1216, 768),
        (1968, 768),
    ]
    configs = get_config()
    for M, K in micro_shapes:
        print('Processing input {} {}'.format(M, K))
        benchmark_softmax(configs, M, K, warmup=10, repeat=20)


if __name__ == "__main__":
    main()
```
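The companion harness itself is left open on the page. Below is a minimal sketch of what it could look like, assuming (1) the fused_bmm_softmax wrapper above is extended to accept BLOCK_SIZE_N and BLOCK_SIZE_K as keyword arguments and (2) triton.testing.do_bench is used for portable timing; neither assumption comes from the original:

```python
import csv
import itertools
import os
import random

import torch
import triton


def benchmark_fused_bmm_softmax(shapes, max_sampled=20):
    # Sample (BLOCK_SIZE_N, BLOCK_SIZE_K) candidates in random order,
    # mirroring the CSV-collection loop of the softmax script above.
    configs = list(itertools.product([32, 64, 128], [16, 32, 64]))
    random.shuffle(configs)
    os.makedirs('data', exist_ok=True)
    for B, M, N, K in shapes:
        a = torch.randn(B, M, K, device='cuda', dtype=torch.float16)
        b = torch.randn(B, K, N, device='cuda', dtype=torch.float16)
        rows = []
        for block_n, block_k in configs[:max_sampled]:
            try:
                # Hypothetical extended wrapper signature (see assumptions above).
                ms = triton.testing.do_bench(
                    lambda: fused_bmm_softmax(a, b,
                                              BLOCK_SIZE_N=block_n,
                                              BLOCK_SIZE_K=block_k))
            except Exception:
                continue  # skip configs that fail to compile or launch
            rows.append([block_n, block_k, ms])
            print('config ({}, {}): {:.3f} ms'.format(block_n, block_k, ms))
        with open('data/bmm_softmax_{}_{}_{}_{}.csv'.format(B, M, N, K),
                  'w', newline='') as f:
            csv.writer(f).writerows(rows)


# benchmark_fused_bmm_softmax([(8, 128, 256, 64)])
```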