Triton Puzzles Lite Study Notes

The official Triton tutorials (there is also an official Chinese translation) still have a rather steep learning curve for beginners. I recently came across a great alternative, Triton-Puzzles-Lite; for a Chinese walkthrough, see Triton魔法. Below are my own solutions, annotated with dimension information; if you have some CUDA background, you will pick this up quickly.

puzzle1

import torch
import triton
import triton.language as tl

# puzzle 1: constant add, N0 = B0
@triton.jit
def add_kernel(
    x_ptr: torch.Tensor, # [N0]
    z_ptr: torch.Tensor, # [N0]
    N0: int,             # vector length
    B0: tl.constexpr     # BLOCK_SIZE
):
    offsets = tl.arange(0, B0)
    x = tl.load(x_ptr + offsets)
    z = x + 10.0
    tl.store(z_ptr + offsets, z)
    return 
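To sanity-check the kernel, here is a minimal host-side launch sketch (my own addition, not part of the puzzle; it assumes a CUDA device and a hypothetical size N0 = 32). Since N0 = B0, a single program covers the whole vector:

N0 = 32
x = torch.randn(N0, device="cuda")
z = torch.empty_like(x)
add_kernel[(1,)](x, z, N0, B0=N0)
assert torch.allclose(z, x + 10.0)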

puzzle2

# puzzle 2: constant add with 1D blocks
@triton.jit
def add_mask2_kernel(
    x_ptr: torch.Tensor, # [N0]
    z_ptr: torch.Tensor, # [N0]
    N0: int,             # vector length
    B0: tl.constexpr     # BLOCK_SIZE
):
):
    pid = tl.program_id(0)
    offsets = pid * B0 + tl.arange(0, B0)
    mask = offsets < N0
    x = tl.load(x_ptr + offsets, mask = mask)
    z = x + 10.0
    tl.store(z_ptr + offsets, z, mask = mask)
    return
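Here N0 may exceed B0, so the launch needs one program per block, and the mask guards the tail. A sketch with hypothetical sizes (B0 must be a power of two for tl.arange):

N0, B0 = 100, 32
x = torch.randn(N0, device="cuda")
z = torch.empty_like(x)
grid = (triton.cdiv(N0, B0),)
add_mask2_kernel[grid](x, z, N0, B0=B0)
assert torch.allclose(z, x + 10.0)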

puzzle3

# puzzle 3: outer vector add
@triton.jit
def add_vec_kernel(
    x_ptr: torch.Tensor, # [N0]
    y_ptr: torch.Tensor, # [N1]
    z_ptr: torch.Tensor, # [N1, N0]
    N0: int,             # length of x
    N1: int,             # length of y
    B0: tl.constexpr,    # BLOCK_SIZE_X, B0 = N0
    B1: tl.constexpr,    # BLOCK_SIZE_Y, B1 = N1
):
    offs_x = tl.arange(0, B0)
    offs_y = tl.arange(0, B1)
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    x = tl.load(x_ptr + offs_x)
    y = tl.load(y_ptr + offs_y)
    z = y[:, None] + x[None, :]
    tl.store(z_ptr + offs_z, z)
    return 
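The kernel computes the outer sum z[j, i] = y[j] + x[i] via broadcasting: y[:, None] has shape [B1, 1], x[None, :] has shape [1, B0], and they broadcast to [B1, B0]. A launch sketch with hypothetical sizes (a single program, since B0 = N0 and B1 = N1):

N0, N1 = 32, 32
x = torch.randn(N0, device="cuda")
y = torch.randn(N1, device="cuda")
z = torch.empty(N1, N0, device="cuda")
add_vec_kernel[(1,)](x, y, z, N0, N1, B0=N0, B1=N1)
assert torch.allclose(z, y[:, None] + x[None, :])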

puzzle4

# puzzle 4: blocked outer vector add
@triton.jit
def add_vec_block_kernel(
    x_ptr: torch.Tensor, # [N0]
    y_ptr: torch.Tensor, # [N1]
    z_ptr: torch.Tensor, # [N1, N0]
    N0: int,             # length of x
    N1: int,             # length of y
    B0: tl.constexpr,    # BLOCK_SIZE_X
    B1: tl.constexpr,    # BLOCK_SIZE_Y
):
    block_idx = tl.program_id(axis = 0)
    block_idy = tl.program_id(axis = 1)
    offs_x = block_idx * B0 + tl.arange(0, B0)
    offs_y = block_idy * B1 + tl.arange(0, B1)
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    mask_x = offs_x < N0
    mask_y = offs_y < N1
    mask_z = mask_y[:, None] & mask_x[None, :]
    x = tl.load(x_ptr + offs_x, mask = mask_x)
    y = tl.load(y_ptr + offs_y, mask = mask_y)
    z = y[:, None] + x[None, :]
    tl.store(z_ptr + offs_z, z, mask = mask_z)
    return
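The blocked version tiles z with a 2D grid, one program per B1 x B0 tile. A launch sketch with hypothetical sizes:

N0, N1, B0, B1 = 100, 90, 32, 32
x = torch.randn(N0, device="cuda")
y = torch.randn(N1, device="cuda")
z = torch.empty(N1, N0, device="cuda")
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))
add_vec_block_kernel[grid](x, y, z, N0, N1, B0=B0, B1=B1)
assert torch.allclose(z, y[:, None] + x[None, :])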

puzzle5

# puzzle 5: fused outer multiplication + relu
@triton.jit
def mul_relu_block_kernel(
    x_ptr: torch.Tensor,    # [N0]
    y_ptr: torch.Tensor,    # [N1]
    z_ptr: torch.Tensor,    # [N1, N0]
    N0: int,                # length of x
    N1: int,                # length of y
    B0: tl.constexpr,       # BLOCK_SIZE_X
    B1: tl.constexpr,       # BLOCK_SIZE_Y
):
    block_idx = tl.program_id(axis = 0)
    block_idy = tl.program_id(axis = 1)
    offs_x = block_idx * B0 + tl.arange(0, B0)
    offs_y = block_idy * B1 + tl.arange(0, B1)
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    mask_x = offs_x < N0
    mask_y = offs_y < N1
    mask_xy = mask_x[None, :] & mask_y[:, None]
    x = tl.load(x_ptr + offs_x, mask = mask_x)
    y = tl.load(y_ptr + offs_y, mask = mask_y)
    z = y[:, None] * x[None, :]
    z = tl.maximum(z, 0)
    tl.store(z_ptr + offs_z, z, mask = mask_xy)
    return
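This fuses the multiply and the relu into a single kernel, so the intermediate product never round-trips through global memory. Cross-checking against the unfused PyTorch reference (hypothetical sizes, my own sketch):

N0, N1, B0, B1 = 100, 90, 32, 32
x = torch.randn(N0, device="cuda")
y = torch.randn(N1, device="cuda")
z = torch.empty(N1, N0, device="cuda")
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))
mul_relu_block_kernel[grid](x, y, z, N0, N1, B0=B0, B1=B1)
assert torch.allclose(z, torch.relu(y[:, None] * x[None, :]))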

puzzle6

# puzzle 6: fused outer multiplication backward
# dz is the gradient of the loss with respect to z, i.e. dl/dz
@triton.jit
def mul_relu_back_block_kernel(
    x_ptr: torch.Tensor,  # [N1, N0]
    y_ptr: torch.Tensor,  # [N1]
    dz_ptr: torch.Tensor, # [N1, N0]
    dx_ptr: torch.Tensor, # [N1, N0]
    N0: int,              # length of x's second dimension
    N1: int,              # length of y
    B0: tl.constexpr,     # BLOCK_SIZE_X
    B1: tl.constexpr,     # BLOCK_SIZE_Y
):
    block_i = tl.program_id(axis = 0)
    block_j = tl.program_id(axis = 1)
    offs_i = block_i * B0 + tl.arange(0, B0)
    offs_j = block_j * B1 + tl.arange(0, B1)
    offs_ji = offs_j[:, None] * N0 + offs_i[None, :]
    mask_i = offs_i < N0
    mask_j = offs_j < N1
    mask_ji = mask_j[:, None] & mask_i[None, :]
    x = tl.load(x_ptr + offs_ji, mask = mask_ji)
    y = tl.load(y_ptr + offs_j, mask = mask_j)
    df = tl.where(x * y[:, None] > 0, 1.0, 0.0)
    dz = tl.load(dz_ptr + offs_ji, mask = mask_ji)
    dxy_x = y[:, None]
    dx = dz * df * dxy_x
    tl.store(dx_ptr + offs_ji, dx, mask = mask_ji)
    return 
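By the chain rule, with z = relu(x * y[:, None]), we get dl/dx = dl/dz * 1[x*y > 0] * y, which is exactly what the kernel computes as dz * df * dxy_x. A quick autograd cross-check (my own sketch, hypothetical sizes):

N0, N1, B0, B1 = 100, 90, 32, 32
x = torch.randn(N1, N0, device="cuda", requires_grad=True)
y = torch.randn(N1, device="cuda")
dz = torch.randn(N1, N0, device="cuda")
dx = torch.empty(N1, N0, device="cuda")
grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))
mul_relu_back_block_kernel[grid](x, y, dz, dx, N0, N1, B0=B0, B1=B1)
z = torch.relu(x * y[:, None])
z.backward(dz)
assert torch.allclose(dx, x.grad)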

puzzle7

# puzzle 7: long sum (row-wise sum over a long axis)
@triton.jit
def long_sum_kernel(
    x_ptr: torch.Tensor,    # [N0, T]
    z_ptr: torch.Tensor,    # [N0]
    N0: int,                # length of z (number of rows)
    T: int,                 # length of x's second dimension
    B0: tl.constexpr,       # BLOCK_SIZE over rows
    B1: tl.constexpr,       # BLOCK_SIZE along T
):
    pid = tl.program_id(axis = 0)
    offs_i = pid * B0 + tl.arange(0, B0)
    mask_i = offs_i < N0
    # accumulator is per-block (B0 rows), and its shape must be constexpr
    z = tl.zeros((B0,), dtype = tl.float32)
    for j in range(0, T, B1):
        # tl.arange needs constexpr bounds, so offset the fixed range by j
        offs_j = j + tl.arange(0, B1)
        mask_j = offs_j < T
        offs_ij = offs_i[:, None] * T + offs_j[None, :]
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # other=0.0 so masked-off lanes don't corrupt the sum
        x = tl.load(x_ptr + offs_ij, mask = mask_ij, other = 0.0)
        z += tl.sum(x, axis = 1)
    tl.store(z_ptr + offs_i, z, mask = mask_i)
    return
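Each program owns B0 rows and walks the long T axis in chunks of B1, accumulating partial sums in registers; the PyTorch equivalent is x.sum(dim=1). A launch sketch with hypothetical sizes:

N0, T, B0, B1 = 4, 200, 4, 32
x = torch.randn(N0, T, device="cuda")
z = torch.empty(N0, device="cuda")
grid = (triton.cdiv(N0, B0),)
long_sum_kernel[grid](x, z, N0, T, B0=B0, B1=B1)
assert torch.allclose(z, x.sum(dim=1))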