[Triton Tutorial] Libdevice (tl_extra.libdevice) Functions

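Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels that run at maximal throughput on modern GPU hardware.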

More Triton documentation in Chinese is available at https://triton.hyper.ai/

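Triton can invoke custom functions from an external library. In this example, we use the libdevice library to apply asin to a tensor. The following links describe the semantics of all available libdevice functions in detail: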

  • CUDA: https://docs.nvidia.com/cuda/libdevice-users-guide/index.html
  • HIP: https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs/ocml/src

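In libdevice.py, we try to group together functions that perform the same computation on different data types. For example, __nv_asin and __nv_asinf both compute the principal value of the arc sine of the input, but __nv_asin operates on double while __nv_asinf operates on float. With Triton, you can simply call tl.math.asin (the kernel below uses libdevice.asin from triton.language.extra), and Triton automatically selects the correct underlying device function based on the input and output types.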
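To make the idea of type-based selection concrete, here is a minimal pure-Python sketch. It is not Triton's actual implementation: the table layout, the dtype labels "fp32" and "fp64", and the helper select_asin_symbol are all hypothetical. Only the two symbol names come from the text above.

```python
# Illustrative dispatch table, NOT Triton's real implementation.
# Only the two symbol names are taken from the tutorial text;
# the dtype labels and the helper below are hypothetical.
ASIN_SYMBOLS = {
    "fp32": "__nv_asinf",  # float variant
    "fp64": "__nv_asin",   # double variant
}


def select_asin_symbol(dtype_name):
    """Return the libdevice symbol for the given input dtype label."""
    try:
        return ASIN_SYMBOLS[dtype_name]
    except KeyError:
        # No variant registered for this dtype in the sketch.
        raise TypeError(f"no asin variant for dtype {dtype_name!r}") from None


print(select_asin_symbol("fp32"))
print(select_asin_symbol("fp64"))
```

The caller writes one function name and the lookup picks the variant, so supporting another data type means adding a table entry rather than a new user-facing function.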

asin kernel

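import torch

import triton
import triton.language as tl
from triton.language.extra import libdevice


@triton.jit
def asin_kernel(
    x_ptr,  # pointer to the input vector
    y_ptr,  # pointer to the output vector
    n_elements,  # number of elements in each vector
    BLOCK_SIZE: tl.constexpr,  # elements handled by one program instance
):
    pid = tl.program_id(axis=0)  # index of this program in the 1D launch grid
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the last block against out-of-bounds access
    x = tl.load(x_ptr + offsets, mask=mask)
    x = libdevice.asin(x)  # the variant matching x's dtype is selected automatically
    tl.store(y_ptr + offsets, x, mask=mask)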
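The block and mask arithmetic in the kernel can be checked on the CPU. The following is a rough emulation on plain Python lists, assuming math.asin as a stand-in for the device function. The helpers cdiv and asin_program are hypothetical names for this sketch, and the loop only mimics what a vectorized program instance does.

```python
import math


def cdiv(a, b):
    """Ceiling division: how many blocks of size b are needed to cover a."""
    return (a + b - 1) // b


def asin_program(x, y, n_elements, block_size, pid):
    """Emulate one program instance of the kernel on Python lists."""
    block_start = pid * block_size
    for offset in range(block_start, block_start + block_size):
        if offset < n_elements:  # plays the role of `mask`
            y[offset] = math.asin(x[offset])


n_elements, block_size = 10, 4
x = [i / n_elements for i in range(n_elements)]
y = [0.0] * n_elements

num_programs = cdiv(n_elements, block_size)  # 3 programs cover 10 elements
for pid in range(num_programs):
    asin_program(x, y, n_elements, block_size, pid)

print(num_programs)
print(y == [math.asin(v) for v in x])
```

For the launch used later in this tutorial (98432 elements with BLOCK_SIZE=1024), the same arithmetic gives 97 programs, and only 128 offsets of the last program pass the mask.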

Using the default libdevice library path

We can use the default libdevice library path encoded in triton/language/math.py.

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

Out:

tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149],
       device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149],
       device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
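One way to read the reported maximum difference is in units of float32 spacing. The sketch below assumes the largest discrepancy occurs where asin(x) lies in [1, 2), which is possible because the inputs are drawn from [0, 1) and asin then ranges up to about 1.57. The output does not say which element produced the maximum, so treat this as an interpretation rather than a measurement.

```python
# float32 stores 23 fraction bits, so adjacent float32 values
# in the interval [1, 2) are 2**-23 apart.
FLOAT32_SPACING_IN_1_2 = 2.0 ** -23

# Value printed in the output above.
reported_max_diff = 2.384185791015625e-07

# Express the difference as a multiple of that spacing.
spacings = reported_max_diff / FLOAT32_SPACING_IN_1_2
print(spacings)
```

Under that assumption the two results differ by two float32 steps at most, which is within the rounding differences one would expect between two independent single-precision asin implementations.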

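Customizing the libdevice library path

We can also customize the libdevice library path by passing it to the asin kernel at launch. Note that the launch reproduced here passes no path argument, so it runs with the default library.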

output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

Out:

tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149],
       device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149],
       device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
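As a hedged sketch of what passing a custom path might look like: the upstream Triton tutorial builds a dictionary that maps a library name to a bitcode file and hands it to the launch through an extern_libs keyword. The helper build_extern_libs and the path below are hypothetical placeholders, and the launch line is left as a comment because it needs a GPU and the tensors defined earlier.

```python
from pathlib import Path


def build_extern_libs(libdevice_bitcode):
    """Map the library name to the bitcode file that should be linked."""
    return {"libdevice": str(Path(libdevice_bitcode))}


# Hypothetical location: replace it with the real file on your system.
extern_libs = build_extern_libs("/path/to/libdevice.10.bc")
print(extern_libs)

# Assumed launch form, not executed here:
# asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
#                   extern_libs=extern_libs)
```

Keeping the path in one dictionary makes it easy to switch between the bundled library and a custom build without touching the kernel itself.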

