【Triton 教程】融合 Softmax (Fused Softmax)

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →https://triton.hyper.ai/

在本教程中,您将编写一个融合的 softmax 操作,该操作在某些类别的矩阵上比 PyTorch 的原生操作快得多:即那些可以适应 GPU 静态随机存取存储器 (SRAM) 的行。

通过这个过程,您将了解以下内容:

  • 内核融合对于带宽受限操作的优势。
  • Triton 中缩减操作。

动机

用于逐元素加法的自定义 GPU 内核有教学上的价值,但在实践中不能带来很大的进展。
让我们转而考虑一个简单的(数值稳定的)softmax 操作:

import torch


import triton
import triton.language as tl
from triton.runtime import driver




def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch
    使用原生 PyTorch 计算 X 的逐行 softmax


    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    我们减去最大元素以避免溢出。Softmax 对于这种偏移是不变的。
    """
    # read  MN elements ; write M  elements
    # 读取 MN 个元素;写入 M 个元素
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    # 读取 MN + M 个元素;写入 MN 个元素
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    # 读取 MN 个元素;写入 MN 个元素
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    # 读取 MN 个元素;写入 M 个元素
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    # 读取 MN + M 个元素;写入 MN 个元素
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    # 总计:读取 5MN + 2M 个元素;写入 3MN + 2M 个元素
    return ret

直接在 PyTorch 中实现时,计算 y = naive_softmax(x) 需要从 DRAM 中读取 5MN+2M 个元素,并写回 3MN+2M 个元素。

这显然是浪费的;我们更希望有一个自定义的「融合」内核,它只需读取一次 X,并在芯片上进行所有必要的计算。

这样做只需要读写 MN 字节,因此我们可以期望理论上的加速约为 4 倍。

torch.jit.script 标志旨在自动执行这种「内核融合」,但正如我们后面将看到的,它仍不够理想。

计算内核

softmax 内核工作原理如下:每个程序加载输入矩阵 X 的一组行,按程序数量跨步处理,对其进行归一化,并将结果写回输出 Y。

注意,Triton 的一个重要限制是每个块必须具有 2 的幂次数的元素,因此,如果我们要处理任意可能的输入形状,我们需要在内部「填充」每一行,并适当保护内存操作。

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    # 程序起始行
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        # 步长表示我们需要对指针增加多少以推进 1 行
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # 块大小是大于 n_cols 的下一个二的幂,因此我们可以适配
        # row in a single block
        # 单个块中的行
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        # 将行加载到 SRAM 中,使用掩码,因为 BLOCK_SIZE 可能大于 n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        # 为了数值稳定性而减去最大值
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        # 请注意,Triton 中的指数运算速度很快,但是是近似的(例如,类似于 CUDA 中的 __expf)。
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        # 将输出写回 DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

我们可以创建一个辅助函数,为任何给定的输入张量建立内核及其(元)参数队列。

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape


    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    # 每次循环迭代的块大小是大于 `x` 列数的最小二的幂
    BLOCK_SIZE = triton.next_power_of_2(n_cols)


    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # 另一个技巧是通过增加每行分配的线程数来要求编译器使用更多的线程块 (`num_warps`)
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    # 将在下一个教程中看到如何以更自然的方式自动调整此值,以免自己进行手动启发式处理。
    num_warps = 8


    # Number of software piepling stages.
    # 软件流水线阶段的数量
    num_stages = 4 if SIZE_SMEM > 200000 else 2


    # Allocate output
    # 分配输出空间
    y = torch.empty_like(x)


    # pre-compile kernel to get register usage and compute thread occupancy.
    # 预编译内核以获取寄存器使用情况并计算线程占用情况。
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)


    num_programs = min(num_programs, n_rows)


    # Create a number of persistent programs.
    # 创建一些持久化程序。
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y

单元测试

我们将在一个具有不规则行和列数的矩阵上测试我们的内核。

这将验证我们的 Padding 机制是否起作用。

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

结果与预期相同。

基准测试

此处将基于输入矩阵中列数的函数进行基准测试,假设有 4096 行,定义 naive_softmax

然后将其性能与(1)torch.softmax 和(2)上面定义的 naive_softmax 进行比较。

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot 用作图表 x 轴的参数名
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name` `x_name` 的不同可能值
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot 参数名,其值对应于图表中不同线条
        line_vals=['triton', 'torch'],  # possible values for `line_arg`` `line_arg` 的可能值
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines 线条的标签名称
        styles=[('blue', '-'), ('green', '-')],  # line styles 线条的样式
        ylabel="GB/s",  # label name for the y-axis y 轴的标签名称
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot. 图表的名称,也用作保存图表的文件名
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name` `x_names` 和 `y_name` 中未包含的函数参数的值
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)




benchmark.run(show_plots=True, print_data=True)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在上面的图中,我们可以看到:

  • Triton 比 Torch JIT 快 4 倍。这证实了我们对 Torch JIT 在这里没有进行任何融合的怀疑。
  • 除了更易于阅读、理解和维护外,Triton 明显比 torch.softmax 快。

Download Jupyter notebook: 02-fused-softmax.ipynb

Download Python source code: 02-fused-softmax.py

Download zipped: 02-fused-softmax.zip

import torch import triton import triton.language as tl @triton.jit def fused_bmm_softmax_kernel( # 输入/输出指针 a_ptr, b_ptr, c_ptr, # 张量维度 B, M, N, K, # 张量A步长 (batch, row, col) stride_ab, stride_am, stride_ak, # 张量B步长 (batch, col, col) stride_bb, stride_bk, stride_bn, # 输出C步长 (batch, row, col) stride_cb, stride_cm, stride_cn, # 分块参数 BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr ): # 程序ID对应输出矩阵中的位置 (batch, row) batch_id = tl.program_id(0) row_id = tl.program_id(1) # 初始化行最大值为负无穷 row_max = tl.full((), float('-inf'), dtype=tl.float32) # Pass 1: 计算当前行的最大值 for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): # 为当前列块预加载A的部分数据 a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn # 累加部分矩阵乘法结果 accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 更新行最大值 max_val = tl.max(accumulator) row_max = tl.maximum(row_max, max_val) # Pass 2: 计算指数和 (exp_sum) exp_sum = tl.zeros((), dtype=tl.float32) for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): # 重新加载数据(与Pass 1类似) a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 计算稳定指数值并累加 exp_vals = tl.exp(accumulator - row_max) exp_sum += tl.sum(exp_vals) # Pass 3: 计算归一化结果并写入 for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 计算最终Softmax结果 exp_vals = tl.exp(accumulator - row_max) softmax_output = exp_vals / exp_sum # 写入输出 c_ptrs = c_ptr + batch_id * stride_cb + row_id * stride_cm + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) * stride_cn tl.store(c_ptrs, softmax_output) # 封装函数供PyTorch调用 def fused_bmm_softmax(a: torch.Tensor, b: torch.Tensor): B, M, K = a.shape B, K, N = b.shape c = torch.empty((B, M, N), device=a.device, dtype=a.dtype) # 配置分块大小 (需根据硬件调整) BLOCK_SIZE_M = 1 # 每个程序实例处理1行 BLOCK_SIZE_N = 64 # 列方向分块大小 BLOCK_SIZE_K = 32 # K维度分块大小 grid = (B, M) # 每个批次每行一个程序实例 fused_bmm_softmax_kernel[grid]( a, b, c, B, M, N, K, a.stride(0), a.stride(1), a.stride(2), b.stride(0), b.stride(1), b.stride(2), c.stride(0), c.stride(1), c.stride(2), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K ) return c把上面的融合kernel替换到下面的代码中:import os import sys import random import itertools import csv from triton import Config import torch import triton import torch_mlu import torch.nn as nn import triton.language as tl os.environ["TRITON_PRINT_AUTOTUNING"] = "1" def get_config(): row_block = list(range(1, 33)) num_stages = [1, 2, 3, 4] configs = list(itertools.product(row_block, num_stages)) random.shuffle(configs) return configs @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, OUTER_ROW_BLOCK: tl.constexpr, ROW_BLOCK: tl.constexpr, num_stages: tl.constexpr): # The rows of the softmax are independent, so we parallelize across rows. pid = tl.program_id(0) ub = tl.minimum(n_rows, (tl.program_id(0) + 1) * OUTER_ROW_BLOCK) outer_raw_idx = tl.program_id(0) * OUTER_ROW_BLOCK for inner_row in range(outer_raw_idx, ub, ROW_BLOCK): row_offsets = tl.arange(0, ROW_BLOCK) row_mask = (inner_row + row_offsets < ub)[:, None] # The stride represents how much we need to increase the pointer to advance 1 row row_start_ptr = input_ptr + (inner_row + row_offsets) * input_row_stride col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] input_ptrs = row_start_ptr[:, None] + col_offsets # Load rows into SRAM row = tl.load(input_ptrs, mask=row_mask).to(tl.float32) # Subtract maximum for numerical stability row_minus_max = row - tl.max(row, axis=1)[:, None] # Note that exponentiation in Triton is fast but approximate numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=1)[:, None] softmax_output = numerator / denominator softmax_output = softmax_output.to(tl.float16) # Write back output to DRAM output_row_start_ptr = output_ptr + (inner_row + row_offsets) * output_row_stride output_ptrs = output_row_start_ptr[:, None] + col_offsets tl.store(output_ptrs, softmax_output, mask=row_mask) def triton_softmax_perf(configs, x): n_rows, n_cols = x.shape BLOCK_SIZE = n_cols # Allocate output y = torch.empty_like(x) # Enqueue kernel. We split all rows into 48 parts evenly. OUTER_ROW_BLOCK = triton.cdiv(n_rows, 48) grid = lambda META: (48, 1, 1) softmax_kernel[(grid)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, OUTER_ROW_BLOCK=OUTER_ROW_BLOCK, ROW_BLOCK=configs[0], num_stages=configs[1]) return y def benchmark_softmax(configs, M, K, warmup, repeat): # 创建输入数据 a = torch.randn(M, K, dtype=torch.float16, device='mlu').contiguous() trainset, count, maxsampled = [], 0, 20 for i in range(len(configs)): try: # filter invalid settings c_triton = triton_softmax_perf(configs[i], a) except: continue for _ in range(warmup): c_triton = triton_softmax_perf(configs[i], a) torch.mlu.synchronize() start_event = torch.mlu.Event(enable_timing=True) end_event = torch.mlu.Event(enable_timing=True) start_event.record() for _ in range(repeat): c_triton = triton_softmax_perf(configs[i], a) torch.mlu.synchronize() end_event.record() torch.mlu.synchronize() trainset.append(list(configs[i])) trainset[count].append(start_event.hardware_time(end_event) / repeat / 1000) print('train data {}: {}'.format(count, trainset[count])) count += 1 if count >= maxsampled: break with open('data/{}_{}.csv'.format(M, K), 'w', newline='') as f: writer = csv.writer(f) writer.writerows(trainset) def main(): micro_shapes = [ (80, 768), (604, 768), (1216, 768), (1968, 768), ] configs = get_config() for M, K in micro_shapes: print('Processing input {} {}'.format(M, K)) benchmark_softmax(configs, M, K, warmup=10, repeat=20) if __name__=="__main__": main()
06-12
import torch import triton import triton.language as tl @triton.jit def fused_bmm_softmax_kernel( # 输入/输出指针 a_ptr, b_ptr, c_ptr, # 张量维度 B, M, N, K, # 张量A步长 (batch, row, col) stride_ab, stride_am, stride_ak, # 张量B步长 (batch, col, col) stride_bb, stride_bk, stride_bn, # 输出C步长 (batch, row, col) stride_cb, stride_cm, stride_cn, # 分块参数 BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr ): # 程序ID对应输出矩阵中的位置 (batch, row) batch_id = tl.program_id(0) row_id = tl.program_id(1) # 初始化行最大值为负无穷 row_max = tl.full((), float('-inf'), dtype=tl.float32) # Pass 1: 计算当前行的最大值 for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): # 为当前列块预加载A的部分数据 a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn # 累加部分矩阵乘法结果 accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 更新行最大值 max_val = tl.max(accumulator) row_max = tl.maximum(row_max, max_val) # Pass 2: 计算指数和 (exp_sum) exp_sum = tl.zeros((), dtype=tl.float32) for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): # 重新加载数据(与Pass 1类似) a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 计算稳定指数值并累加 exp_vals = tl.exp(accumulator - row_max) exp_sum += tl.sum(exp_vals) # Pass 3: 计算归一化结果并写入 for n_block_idx in range(0, tl.cdiv(N, BLOCK_SIZE_N)): a_ptrs = a_ptr + batch_id * stride_ab + row_id * stride_am + \ tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_ak b_ptrs = b_ptr + batch_id * stride_bb + \ tl.arange(0, BLOCK_SIZE_K)[:, None] * stride_bk + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=tl.arange(0, BLOCK_SIZE_K)[:, None] < K - k * BLOCK_SIZE_K, other=0.0) partial = tl.dot(a, b) accumulator += partial a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # 计算最终Softmax结果 exp_vals = tl.exp(accumulator - row_max) softmax_output = exp_vals / exp_sum # 写入输出 c_ptrs = c_ptr + batch_id * stride_cb + row_id * stride_cm + \ (n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) * stride_cn tl.store(c_ptrs, softmax_output) 封装函数供PyTorch调用 def fused_bmm_softmax(a: torch.Tensor, b: torch.Tensor): B, M, K = a.shape B, K, N = b.shape c = torch.empty((B, M, N), device=a.device, dtype=a.dtype) # 配置分块大小 (需根据硬件调整) BLOCK_SIZE_M = 1 # 每个程序实例处理1行 BLOCK_SIZE_N = 64 # 列方向分块大小 BLOCK_SIZE_K = 32 # K维度分块大小 grid = (B, M) # 每个批次每行一个程序实例 fused_bmm_softmax_kernel[grid]( a, b, c, B, M, N, K, a.stride(0), a.stride(1), a.stride(2), b.stride(0), b.stride(1), b.stride(2), c.stride(0), c.stride(1), c.stride(2), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K ) return c把上面的融合kernel跟据下面的代码的思路写一个配套的代码(完全与融合kernel配套,只是代码思路类似而已):import os import sys import random import itertools import csv from triton import Config import torch import triton import torch_mlu import torch.nn as nn import triton.language as tl os.environ[“TRITON_PRINT_AUTOTUNING”] = “1” def get_config(): row_block = list(range(1, 33)) num_stages = [1, 2, 3, 4] configs = list(itertools.product(row_block, num_stages)) random.shuffle(configs) return configs @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, OUTER_ROW_BLOCK: tl.constexpr, ROW_BLOCK: tl.constexpr, num_stages: tl.constexpr): # The rows of the softmax are independent, so we parallelize across rows. pid = tl.program_id(0) ub = tl.minimum(n_rows, (tl.program_id(0) + 1) * OUTER_ROW_BLOCK) outer_raw_idx = tl.program_id(0) * OUTER_ROW_BLOCK for inner_row in range(outer_raw_idx, ub, ROW_BLOCK): row_offsets = tl.arange(0, ROW_BLOCK) row_mask = (inner_row + row_offsets < ub)[:, None] # The stride represents how much we need to increase the pointer to advance 1 row row_start_ptr = input_ptr + (inner_row + row_offsets) * input_row_stride col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] input_ptrs = row_start_ptr[:, None] + col_offsets # Load rows into SRAM row = tl.load(input_ptrs, mask=row_mask).to(tl.float32) # Subtract maximum for numerical stability row_minus_max = row - tl.max(row, axis=1)[:, None] # Note that exponentiation in Triton is fast but approximate numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=1)[:, None] softmax_output = numerator / denominator softmax_output = softmax_output.to(tl.float16) # Write back output to DRAM output_row_start_ptr = output_ptr + (inner_row + row_offsets) * output_row_stride output_ptrs = output_row_start_ptr[:, None] + col_offsets tl.store(output_ptrs, softmax_output, mask=row_mask) def triton_softmax_perf(configs, x): n_rows, n_cols = x.shape BLOCK_SIZE = n_cols # Allocate output y = torch.empty_like(x) # Enqueue kernel. We split all rows into 48 parts evenly. OUTER_ROW_BLOCK = triton.cdiv(n_rows, 48) grid = lambda META: (48, 1, 1) softmax_kernel[(grid)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, OUTER_ROW_BLOCK=OUTER_ROW_BLOCK, ROW_BLOCK=configs[0], num_stages=configs[1]) return y def benchmark_softmax(configs, M, K, warmup, repeat): # 创建输入数据 a = torch.randn(M, K, dtype=torch.float16, device=‘mlu’).contiguous() trainset, count, maxsampled = [], 0, 20 for i in range(len(configs)): try: # filter invalid settings c_triton = triton_softmax_perf(configs[i], a) except: continue for _ in range(warmup): c_triton = triton_softmax_perf(configs[i], a) torch.mlu.synchronize() start_event = torch.mlu.Event(enable_timing=True) end_event = torch.mlu.Event(enable_timing=True) start_event.record() for _ in range(repeat): c_triton = triton_softmax_perf(configs[i], a) torch.mlu.synchronize() end_event.record() torch.mlu.synchronize() trainset.append(list(configs[i])) trainset[count].append(start_event.hardware_time(end_event) / repeat / 1000) print('train data {}: {}'.format(count, trainset[count])) count += 1 if count >= maxsampled: break with open('data/{}_{}.csv'.format(M, K), 'w', newline='') as f: writer = csv.writer(f) writer.writerows(trainset) def main(): micro_shapes = [ (80, 768), (604, 768), (1216, 768), (1968, 768), ] configs = get_config() for M, K in micro_shapes: print('Processing input {} {}'.format(M, K)) benchmark_softmax(configs, M, K, warmup=10, repeat=20) if name==“main”: main()
06-12
<think>好的,用户问的是关于NVIDIA Triton的代码教程。首先,我需要确定用户对Triton的了解程度。Triton是一个推理服务软件,用于部署机器学习模型,支持多种框架,比如TensorFlow、PyTorch等。用户可能想了解如何编写模型服务器的配置文件、部署模型或者进行推理请求。 接下来,我应该提供一个简单的例子,帮助用户快速入门。比如,如何设置模型仓库,编写config.pbtxt文件,以及使用Python客户端发送推理请求。确保代码示例正确,并且步骤清晰。 然后,用户可能的需求包括模型的并行处理、动态批处理配置,或者性能优化。需要考虑这些方面,给出相关的问题建议,帮助用户深入学习。同时,要注意用户可能遇到的常见问题,比如模型格式转换、服务部署错误等,相关问题中应该涵盖这些点。 最后,确保回答符合格式要求,没有使用Markdown,代码块正确,并在结尾列出3-5个相关问题。检查是否有遗漏的部分,比如是否需要安装步骤或者依赖项的说明,但根据用户的问题,可能更关注代码层面的内容,所以暂时先集中在示例和相关问题上。</think>NVIDIA Triton Inference Server 是用于部署 AI 模型的开源推理服务软件。以下是一个基础代码示例: 1. **模型配置文件 (model/config.pbtxt)** ```protobuf name: "example_model" platform: "onnxruntime_onnx" max_batch_size: 8 input [ { name: "input0" data_type: TYPE_FP32 dims: [ 16 ] } ] output [ { name: "output0" data_type: TYPE_FP32 dims: [ 16 ] } ] ``` 2. **Python推理请求示例** ```python import numpy as np import tritonclient.http as httpclient triton_client = httpclient.InferenceServerClient(url="localhost:8000") # 创建输入数据 inputs = [] inputs.append(httpclient.InferInput("input0", [1,16], "FP32")) inputs[0].set_data_from_numpy(np.random.rand(1,16).astype(np.float32)) # 发起推理请求 results = triton_client.infer(model_name="example_model", inputs=inputs) output_data = results.as_numpy("output0") print("推理结果:", output_data.shape) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值