序
官网的 Triton 例子（中文官方教程）对于新手来说学习曲线还是过于陡峭了。最近发现了一个很好的教程 Triton-Puzzles-Lite，中文可以参考这个 Triton魔法。下面是我自己写的解答，增加了张量的维度信息；如果你有一些 CUDA 基础，那么上手会很快。
puzzle1
# puzzle 1: constant add, z[i] = x[i] + 10.
# A single program covers the whole vector and no masking is done, so this
# relies on N0 == B0.
@triton.jit
def add_kernel(
    x_ptr: torch.Tensor,  # input vector, [N0]
    z_ptr: torch.Tensor,  # output vector, [N0]
    N0: int,  # vector length (not read here; B0 is assumed to equal it)
    B0: tl.constexpr,  # BLOCK_SIZE
):
    idx = tl.arange(0, B0)
    tl.store(z_ptr + idx, tl.load(x_ptr + idx) + 10.0)
puzzle2
# puzzle 2: 1D blocked constant add, z[i] = x[i] + 10.
# Each program handles one B0-sized chunk of the vector; the bounds mask keeps
# the final (possibly partial) chunk from touching elements at or past N0.
@triton.jit
def add_mask2_kernel(
    x_ptr: torch.Tensor,  # input vector, [N0]
    z_ptr: torch.Tensor,  # output vector, [N0]
    N0: int,  # vector length
    B0: tl.constexpr,  # BLOCK_SIZE
):
    start = tl.program_id(0) * B0
    idx = start + tl.arange(0, B0)
    in_bounds = idx < N0
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(z_ptr + idx, vals + 10.0, mask=in_bounds)
puzzle3
# puzzle 3: outer vector add, z[j, i] = x[i] + y[j].
# A single program computes the whole [N1, N0] output. B0 and B1 must be
# powers of two (a tl.arange requirement) with B0 >= N0 and B1 >= N1; the
# bounds masks let N0 / N1 themselves be any size up to those block sizes
# (in the original puzzle B0 == N0 and B1 == N1).
@triton.jit
def add_vec_kernel(
    x_ptr: torch.Tensor,  # [N0]
    y_ptr: torch.Tensor,  # [N1]
    z_ptr: torch.Tensor,  # [N1, N0], row-major: row j <-> y[j], column i <-> x[i]
    N0: int,  # length of x (= number of columns of z)
    N1: int,  # length of y (= number of rows of z)
    B0: tl.constexpr,  # BLOCK_X_SIZE, >= N0
    B1: tl.constexpr,  # BLOCK_Y_SIZE, >= N1
):
    offs_x = tl.arange(0, B0)
    offs_y = tl.arange(0, B1)
    mask_x = offs_x < N0
    mask_y = offs_y < N1
    # Flat offset of z[j, i]: each row of z holds N0 elements.
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    mask_z = mask_y[:, None] & mask_x[None, :]
    x = tl.load(x_ptr + offs_x, mask=mask_x)
    y = tl.load(y_ptr + offs_y, mask=mask_y)
    z = y[:, None] + x[None, :]
    tl.store(z_ptr + offs_z, z, mask=mask_z)
puzzle4
# puzzle 4: blocked outer vector add, z[j, i] = x[i] + y[j].
# Launched on a 2D grid: program axis 0 tiles x (columns of z) in B0-sized
# chunks, program axis 1 tiles y (rows of z) in B1-sized chunks.
@triton.jit
def add_vec_block_kernel(
    x_ptr: torch.Tensor,  # [N0]
    y_ptr: torch.Tensor,  # [N1]
    z_ptr: torch.Tensor,  # [N1, N0]
    N0: int,  # length of x (= number of columns of z)
    N1: int,  # length of y (= number of rows of z)
    B0: tl.constexpr,  # BLOCK_SIZE_X
    B1: tl.constexpr,  # BLOCK_SIZE_Y
):
    block_idx = tl.program_id(axis=0)
    block_idy = tl.program_id(axis=1)
    offs_x = block_idx * B0 + tl.arange(0, B0)
    offs_y = block_idy * B1 + tl.arange(0, B1)
    # Flat offset of z[j, i]: each row of z holds N0 elements.
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    mask_x = offs_x < N0
    mask_y = offs_y < N1
    mask_z = mask_y[:, None] & mask_x[None, :]
    x = tl.load(x_ptr + offs_x, mask=mask_x)
    y = tl.load(y_ptr + offs_y, mask=mask_y)
    z = y[:, None] + x[None, :]
    # tl.store(pointer, value, mask=...): the computed tile z must be passed
    # as the value argument.
    tl.store(z_ptr + offs_z, z, mask=mask_z)
puzzle5
# puzzle 5: fused outer multiply + ReLU, z[j, i] = max(y[j] * x[i], 0).
# 2D grid: program axis 0 tiles x (columns of z), axis 1 tiles y (rows of z).
@triton.jit
def mul_relu_block_kernel(
    x_ptr: torch.Tensor,  # [N0]
    y_ptr: torch.Tensor,  # [N1]
    z_ptr: torch.Tensor,  # [N1, N0]
    N0: int,  # length of x (= number of columns of z)
    N1: int,  # length of y (= number of rows of z)
    B0: tl.constexpr,  # BLOCK_SIZE_X
    B1: tl.constexpr,  # BLOCK_SIZE_Y
):
    cols = tl.program_id(axis=0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(axis=1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1
    x = tl.load(x_ptr + cols, mask=col_ok)
    y = tl.load(y_ptr + rows, mask=row_ok)
    prod = y[:, None] * x[None, :]
    tl.store(
        z_ptr + (rows[:, None] * N0 + cols[None, :]),
        tl.maximum(prod, 0),
        mask=row_ok[:, None] & col_ok[None, :],
    )
puzzle6
# puzzle 6: fused outer multiplication backward.
# dz is the upstream gradient dL/dz. This kernel writes
#   dx[j, i] = dz[j, i] * (1 if x[j, i] * y[j] > 0 else 0) * y[j]
# i.e. the gradient of relu(x * y) with respect to x.
# 2D grid: program axis 0 tiles columns i (length N0), axis 1 tiles rows j
# (length N1).
@triton.jit
def mul_relu_back_block_kernel(
    x_ptr: torch.Tensor,  # [N1, N0]
    y_ptr: torch.Tensor,  # [N1]
    dz_ptr: torch.Tensor,  # [N1, N0], upstream gradient dL/dz
    dx_ptr: torch.Tensor,  # [N1, N0], output gradient dL/dx
    N0: int,  # number of columns of x
    N1: int,  # number of rows of x (= length of y)
    B0: tl.constexpr,  # BLOCK_SIZE_X (columns per program)
    B1: tl.constexpr,  # BLOCK_SIZE_Y (rows per program)
):
    block_i = tl.program_id(axis=0)
    block_j = tl.program_id(axis=1)
    offs_i = block_i * B0 + tl.arange(0, B0)
    offs_j = block_j * B1 + tl.arange(0, B1)
    # Flat offset of element [j, i]: each row holds N0 elements.
    offs_ji = offs_j[:, None] * N0 + offs_i[None, :]
    mask_i = offs_i < N0
    mask_j = offs_j < N1
    # Combine the two boolean bounds masks (not the raw row offsets) so that
    # exactly the in-bounds [j, i] entries are loaded and stored.
    mask_ji = mask_j[:, None] & mask_i[None, :]
    x = tl.load(x_ptr + offs_ji, mask=mask_ji)
    y = tl.load(y_ptr + offs_j, mask=mask_j)
    # ReLU derivative: 1 where the pre-activation x * y is positive, else 0.
    df = tl.where(x * y[:, None] > 0, 1.0, 0.0)
    dz = tl.load(dz_ptr + offs_ji, mask=mask_ji)
    # d(x * y)/dx = y, broadcast along each row.
    dxy_x = y[:, None]
    dx = dz * df * dxy_x
    tl.store(dx_ptr + offs_ji, dx, mask=mask_ji)
puzzle7
# puzzle 7: long sum, z[i] = sum over j of x[i, j].
# 1D grid over rows: each program reduces B0 rows of x, walking the T columns
# in chunks of B1 and accumulating partial row sums in float32.
@triton.jit
def long_sum_kernel(
    x_ptr: torch.Tensor,  # [N0, T]
    z_ptr: torch.Tensor,  # [N0]
    N0: int,  # number of rows of x (= length of z)
    T: int,  # number of columns of x
    B0: tl.constexpr,  # rows handled per program
    B1: tl.constexpr,  # columns loaded per loop iteration
):
    pid = tl.program_id(axis=0)
    offs_i = pid * B0 + tl.arange(0, B0)
    mask_i = offs_i < N0
    # The accumulator shape must be a compile-time constant that matches
    # offs_i (B0 rows), not the runtime total length N0.
    z = tl.zeros((B0,), dtype=tl.float32)
    for j in range(0, T, B1):
        # tl.arange needs constexpr bounds, so shift a fixed [0, B1) range by
        # the runtime loop index rather than calling tl.arange(j, j + B1).
        offs_j = j + tl.arange(0, B1)
        mask_j = offs_j < T
        offs_ij = offs_i[:, None] * T + offs_j[None, :]
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # other=0.0: masked-off columns of the last chunk must contribute
        # nothing to the row sums (masked loads are otherwise undefined).
        x = tl.load(x_ptr + offs_ij, mask=mask_ij, other=0.0)
        z += tl.sum(x, axis=1)
    tl.store(z_ptr + offs_i, z, mask=mask_i)