Understanding torch._dynamo (2): The Backend

0 Overview

TorchDynamo is a Python-level just-in-time (JIT) compiler designed to make unmodified PyTorch programs run faster. It achieves this through Python frame evaluation hooks (the PEP 523 API), generating and optimizing code dynamically at runtime. This lets TorchDynamo handle a wide range of Python code, including code with control flow (loops and conditionals), without requiring any modifications.
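To make Dynamo's role concrete, here is a minimal custom backend using the standard torch.compile backend API: Dynamo captures Python bytecode into an FX graph and hands it, together with example inputs, to the backend, which may return any callable.

import torch

# A trivial "backend": Dynamo hands us the captured FX graph plus example
# inputs; we print the graph and return its forward unchanged.
def my_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)
    return gm.forward  # any callable with the same signature works

@torch.compile(backend=my_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))  # capture + "compilation" happen on the first call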

The full PyTorch compilation stack is shown below:
[figure: PyTorch compilation stack]

The overall backend flow is shown below:
[figure: backend overall flow]

Taking Triton as an example, the backend first performs a lowering pass, then schedules the result, and only then generates the Triton kernels.
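To observe these stages on a real model, recent PyTorch 2.x builds provide a debug dump; a minimal sketch, assuming the TORCH_COMPILE_DEBUG switch and the usual dump file names:

# Run as: TORCH_COMPILE_DEBUG=1 python demo.py
# Artifacts land under ./torch_compile_debug/ (names may vary by version).
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(16, 32, 3), nn.ReLU()).cuda()
model = torch.compile(model)  # default backend: inductor
model(torch.randn(2, 16, 8, 8, device="cuda"))
# fx_graph_readable.py  -- the aten IR handed to the backend
# ir_post_fusion.txt    -- the loop-level IR after scheduling/fusion
# output_code.py        -- the generated Triton kernels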

1 Loop-level IR

The lowering here is expressed in a loop-level IR, which interprets the aten IR statement by statement, with each interpretation linked to what came before. The node types at this IR level include:

  • PointWise
  • Reduction
  • TensorBox
  • MatrixMultiplyAdd

Besides these, there are a number of other types; a toy sketch of the first two follows below.
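Conceptually, a Pointwise node is little more than an index-to-value function plus iteration ranges, and a Reduction additionally carries reduction ranges and a reduction type. A toy sketch of the idea (the real classes live in torch/_inductor/ir.py and are considerably richer):

from dataclasses import dataclass
from typing import Callable, List

@dataclass
class Pointwise:
    inner_fn: Callable[[List[int]], str]  # logical index -> scalar expression
    ranges: List[int]                     # iteration space, e.g. [2, 32, 6, 6]

@dataclass
class Reduction(Pointwise):
    reduction_ranges: List[int]           # the dims being reduced over
    reduction_type: str                   # e.g. "sum", "amax"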

The processing flow at this level:

  • Start from the aten IR received from the frontend:
    [figure: aten IR of the example]

  • Each of the statements above is translated into loop-level IR:

    • convert_element_type:
      [figure: loop-level IR for convert_element_type]

    • amax:
      [figure: loop-level IR for amax]

      The result of this computation is stored into buf0.

    • sub:
      [figure: loop-level IR for sub]

      Because amax stored its result into buf0, this step can load directly from buf0.

    • exp:
      [figure: loop-level IR for exp]

      If the previous IR node is pointwise, it is fused with the current one. Here, for example, the sub IR simply gains the line tmp4 = exp(tmp3), and the return value changes to tmp4.

      This pass interprets the aten IR statement by statement, linking each step to what came before, and finally yields a fused loop-level IR (sketched in code below).
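The fusion step can be pictured as composing inner functions: the new op wraps the previous node's computation, and no intermediate buffer is materialized. A toy illustration (the load helper is a hypothetical stand-in; this is not Inductor's real code):

import math

def sub_inner(idx, load):
    tmp3 = load("input", idx) - load("buf0", idx[:-1])  # x - amax
    return tmp3

def fuse(prev_inner, op):
    def inner(idx, load):
        tmp = prev_inner(idx, load)  # everything the sub IR already computed
        return op(tmp)               # plus one new line: tmp4 = exp(tmp3)
    return inner

exp_inner = fuse(sub_inner, math.exp)  # the fused node now returns tmp4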

2 Schedule

Take the following code as an example:

import torch
from torch import nn

if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(16, 32, 3),
        nn.BatchNorm2d(32),
        nn.ReLU(),
    ).cuda()
    model = torch.compile(model)
    x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")
    model(x)

This builds 11 buffers at the loop-level IR stage. These buffers are then scheduled, which covers the following:
[figure: scheduling of the 11 buffers]

Some of these buffers have Reduction enabled; in other words, reduction here is a property of a buffer. The reduction buffers are placed together to generate a single kernel, while each of the remaining buffers generates its own kernel (kernel here means a Triton kernel, which in practice we can think of as a function). Only reduction kernels contain loop statements; purely pointwise computation generates no loops. The sketch below illustrates this grouping.
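A toy sketch of the grouping rule as just described (the real decisions live in Inductor's Scheduler and are considerably more involved):

def group_into_kernels(buffers):
    # buffers: list of (name, is_reduction) pairs
    kernels, reduction_group = [], []
    for name, is_reduction in buffers:
        if is_reduction:
            reduction_group.append(name)  # reductions are fused into one kernel
        else:
            kernels.append([name])        # each pointwise buffer gets its own
    if reduction_group:
        kernels.append(reduction_group)
    return kernels

print(group_into_kernels([("buf1", False), ("buf2", True), ("buf3", True)]))
# -> [['buf1'], ['buf2', 'buf3']]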

3 Triton Kernel

Finally come the Triton kernels themselves. The generation strategy (sketched in toy form after this list) is:

  • First generate the load statements
  • Then generate the compute statements
  • Then generate the store statements
  • Combine the three kinds of statements into one kernel
  • Combine all kernels, a call function, and a __main__ block into a single .py file
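A toy picture of that assembly step (the real code generation lives under torch/_inductor/codegen/; this only shows the string-building idea):

def assemble_kernel(name, loads, compute, stores):
    prologue = [
        "xoffset = tl.program_id(0) * XBLOCK",
        "xindex = xoffset + tl.arange(0, XBLOCK)[:]",
        "xmask = xindex < xnumel",
        "x0 = xindex",
    ]
    body = "\n    ".join(prologue + loads + compute + stores)
    return (
        "@triton.jit\n"
        f"def {name}(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n"
        f"    {body}\n"
    )

print(assemble_kernel(
    "triton_",
    loads=["tmp0 = tl.load(in_ptr0 + x0, xmask)"],
    compute=["tmp1 = tmp0 + 1"],
    stores=["tl.store(out_ptr0 + x0, tmp1, xmask)"],
))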

The file generated for the example above is as follows:

from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


triton__0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 36) % 32
    tmp0 = tl.load(in_out_ptr0 + (x3), xmask)
    tmp1 = tl.load(in_ptr0 + (x1), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(in_out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
''')


triton__1 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor

@reduction(
    size_hints=[32, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 32
    rnumel = 72
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex % 36
        r2 = (rindex // 36)
        tmp0 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
        _tmp1 = tl.where(rmask & xmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.sum(_tmp1, 1)[:, None]
    tmp6 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = 72.0
    tmp3 = tmp1 / tmp2
    tmp4 = 0.1
    tmp5 = tmp3 * tmp4
    tmp7 = 0.9
    tmp8 = tmp6 * tmp7
    tmp9 = tmp5 + tmp8
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp3, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp9, xmask)
    _tmp13 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex % 36
        r2 = (rindex // 36)
        tmp10 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp11 = tmp10 - tmp3
        tmp12 = tmp11 * tmp11
        _tmp13 = tl.where(rmask & xmask, _tmp13 + tmp12, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp13, xmask)
    tmp23 = tl.load(in_ptr2 + (x0), xmask)
    tmp14 = 72.0
    tmp15 = tmp13 / tmp14
    tmp16 = 1e-05
    tmp17 = tmp15 + tmp16
    tmp18 = tl.libdevice.rsqrt(tmp17)
    tmp19 = 1.0140845070422535
    tmp20 = tmp15 * tmp19
    tmp21 = 0.1
    tmp22 = tmp20 * tmp21
    tmp24 = 0.9
    tmp25 = tmp23 * tmp24
    tmp26 = tmp22 + tmp25
    tl.store(out_ptr2 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp18, xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp26, xmask)
''')


triton__2 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*i1', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 36) % 32
    tmp0 = tl.load(in_ptr0 + (x3), xmask)
    tmp1 = tl.load(in_ptr1 + (x1), xmask)
    tmp3 = tl.load(in_ptr2 + (x1), xmask)
    tmp10 = tl.load(in_ptr3 + (x1), xmask)
    tmp12 = tl.load(in_ptr4 + (x1), xmask)
    tmp2 = tmp0 - tmp1
    tmp4 = 72.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = tl.libdevice.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tmp14 = tl.where(0 != 0, 0, tl.where(0 > tmp13, 0, tmp13))
    tmp15 = 0.0
    tmp16 = tmp14 <= tmp15
    tl.store(out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp14, xmask)
    tl.store(out_ptr1 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp16, xmask)
''')


triton__3 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[1], filename=__file__, meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0_load = tl.load(in_ptr0 + (0))
    tmp0 = tl.broadcast_to(tmp0_load, [XBLOCK])
    tmp1 = 1
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp2, None)
''')


async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8 = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = aten.convolution(primals_8, primals_1, None, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
        assert_size_stride(buf0, (2, 32, 6, 6), (1152, 36, 6, 1))
        buf1 = buf0; del buf0  # reuse
        stream0 = get_cuda_stream(0)
        triton__0.run(buf1, primals_2, 2304, grid=grid(2304), stream=stream0)
        del primals_2
        buf2 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
        buf3 = buf2; del buf2  # reuse
        buf6 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        buf4 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
        buf5 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        buf7 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        triton__1.run(buf3, buf1, primals_5, primals_6, buf6, buf4, buf5, buf7, 32, 72, grid=grid(32), stream=stream0)
        del primals_5
        del primals_6
        buf8 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.float32)
        buf9 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.bool)
        triton__2.run(buf1, buf3, buf4, primals_3, primals_4, buf8, buf9, 2304, grid=grid(2304), stream=stream0)
        del buf4
        del primals_4
        buf10 = empty_strided((), (), device='cuda', dtype=torch.int64)
        triton__3.run(primals_7, buf10, 1, grid=grid(1), stream=stream0)
        del primals_7
        return (buf6, buf7, buf10, buf8, primals_1, primals_3, primals_8, buf1, buf5, buf9, as_strided(buf3, (1, 32, 1, 1), (32, 1, 1, 1)), )


if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((32, 16, 3, 3), (144, 9, 3, 1), device='cuda:0', dtype=torch.float32)
    primals_2 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_3 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_4 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_5 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_6 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_7 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
    primals_8 = rand_strided((2, 16, 8, 8), (1024, 64, 8, 1), device='cuda:0', dtype=torch.float32)
    print_performance(lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8]))

This file is generated in the /tmp directory with a .py suffix. Running it directly yields the performance number, and the values produced by the computation can likewise be captured while it runs. The sketch below shows roughly what that measurement amounts to.
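Roughly speaking, print_performance amounts to CUDA-event timing of the generated call(); a sketch of the idea, not its exact implementation:

import torch

def time_call(call, make_args, warmup=3, iters=10):
    # call(args) clears its argument list, so every invocation needs a fresh copy
    for _ in range(warmup):
        call(make_args())
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        call(make_args())
    end.record()
    torch.cuda.synchronize()
    print(f"{start.elapsed_time(end) / iters:.3f} ms / iter")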

4 loop-level IR --> triton kernel

The Triton kernels are generated through the GraphLowering data structure's run(*example_input) method, where the example inputs are fake tensors (a short sketch of fake tensors follows the dump below):


Graph ID : 0 

Input : {
    'primals_1': TensorBox(StorageBox(
  InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
)), 
    'primals_2': TensorBox(StorageBox(
  InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_3': TensorBox(StorageBox(
  InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_4': TensorBox(StorageBox(
  InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_5': TensorBox(StorageBox(
  InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_6': TensorBox(StorageBox(
  InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_7': TensorBox(StorageBox(
  InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]))
)), 
    'primals_8': TensorBox(StorageBox(
  InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
))} 

Origin Input : {
    'primals_1': InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1])), 
    'primals_2': InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_3': InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_4': InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_5': InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_6': InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
	'primals_7': InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[])), 
    'primals_8': InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))} 


Output : [
    StorageBox(ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf3, i0)
    tmp1 = constant(0.1, torch.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = load(primals_5, i0)
    tmp4 = constant(0.9, torch.float32)
    tmp5 = tmp3 * tmp4
    tmp6 = tmp2 + tmp5
    return tmp6
    ,
    ranges=[32],
    origins={add_2}
  ))
), 
	StorageBox(ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf4, i0)
    tmp1 = index_expr(72, torch.float32)
    tmp2 = tmp0 / tmp1
    tmp3 = constant(1.0140845070422535, torch.float32)
    tmp4 = tmp2 * tmp3
    tmp5 = constant(0.1, torch.float32)
    tmp6 = tmp4 * tmp5
    tmp7 = load(primals_6, i0)
    tmp8 = constant(0.9, torch.float32)
    tmp9 = tmp7 * tmp8
    tmp10 = tmp6 + tmp9
    return tmp10
    ,
    ranges=(32,),
    origins={add_3}
  ))
), 
	StorageBox(ComputedBuffer(name='buf10', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]), data=Pointwise(
    'cuda',
    torch.int64,
    tmp0 = load(primals_7, 0)
    tmp1 = constant(1, torch.int64)
    tmp2 = tmp0 + tmp1
    return tmp2
    ,
    ranges=[],
    origins={primals_7, clone_2, add}
  ))
), 
	StorageBox(ComputedBuffer(name='buf8', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf1, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
    tmp1 = load(buf3, i1)
    tmp2 = tmp0 - tmp1
    tmp3 = load(buf4, i1)
    tmp4 = index_expr(72, torch.float32)
    tmp5 = tmp3 / tmp4
    tmp6 = constant(1e-05, torch.float32)
    tmp7 = tmp5 + tmp6
    tmp8 = rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp10 = load(primals_3, i1)
    tmp11 = tmp9 * tmp10
    tmp12 = load(primals_4, i1)
    tmp13 = tmp11 + tmp12
    tmp14 = relu(tmp13)
    return tmp14
    ,
    ranges=[2, 32, 6, 6],
    origins={relu}
  ))
), 
	StorageBox(InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
), 
	StorageBox(InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
), 
	StorageBox(InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
), 
	StorageBox(ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf0, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
    tmp1 = load(primals_2, i1)
    tmp2 = tmp0 + tmp1
    return tmp2
    ,
    ranges=[2, 32, 6, 6],
    origins={convolution}
  ))
), 
	StorageBox(ComputedBuffer(name='buf5', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf4, i0)
    tmp1 = index_expr(72, torch.float32)
    tmp2 = tmp0 / tmp1
    tmp3 = constant(1e-05, torch.float32)
    tmp4 = tmp2 + tmp3
    tmp5 = rsqrt(tmp4)
    return tmp5
    ,
    ranges=(32,),
    origins={squeeze_1}
  ))
), 
	StorageBox(ComputedBuffer(name='buf9', layout=FixedLayout('cuda', torch.bool, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.bool,
    tmp0 = load(buf8, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
    tmp1 = constant(0, torch.float32)
    tmp2 = tmp0 <= tmp1
    return tmp2
    ,
    ranges=[2, 32, 6, 6],
    origins={le}
  ))
), 
ReinterpretView(StorageBox(ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 32, 32]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf2, i1)
      tmp1 = index_expr(72, torch.float32)
      tmp2 = tmp0 / tmp1
      return tmp2
      ,
      ranges=[1, 32, 1, 1],
      origins={convolution, var_mean}
    ))
  ),
  FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 1, 1]),
  no origins?
)]
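As noted above, the example inputs handed to run are fake tensors: they carry shape, stride, dtype, and device metadata but no real storage, which is all GraphLowering needs to build the boxes shown in the dump. A minimal sketch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
fake = fake_mode.from_tensor(torch.randn(2, 16, 8, 8))
print(fake.shape, fake.stride(), fake.dtype)  # metadata only, no data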

5 The Purpose of Scheduling

  • The purpose of scheduling: decomposition has already been performed earlier (generally when the graph was converted to aten ops), so the goal here is to adjust the order of the buffers, i.e. to schedule memory, in order to optimize the efficiency of memory accesses.

6 aten IR --> loop-level IR

The conversion from aten IR to loop-level IR is completed in torch/_inductor/compile_fx.py (around line 179); the code stored in the incoming gm is:

def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8):
    clone = torch.ops.aten.clone.default(primals_5);  primals_5 = None
    clone_1 = torch.ops.aten.clone.default(primals_6);  primals_6 = None
    clone_2 = torch.ops.aten.clone.default(primals_7);  primals_7 = None
    convolution = torch.ops.aten.convolution.default(primals_8, primals_1, primals_2, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  primals_2 = None
    add = torch.ops.aten.add.Tensor(clone_2, 1);  clone_2 = None
    var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
    getitem = var_mean[0]
    getitem_1 = var_mean[1];  var_mean = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1e-05)
    rsqrt = torch.ops.aten.rsqrt.default(add_1);  add_1 = None
    sub = torch.ops.aten.sub.Tensor(convolution, getitem_1)
    mul = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = None
    squeeze = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]);  getitem_1 = None
    squeeze_1 = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]);  rsqrt = None
    mul_1 = torch.ops.aten.mul.Tensor(squeeze, 0.1)
    mul_2 = torch.ops.aten.mul.Tensor(clone, 0.9);  clone = None
    add_2 = torch.ops.aten.add.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
    squeeze_2 = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]);  getitem = None
    mul_3 = torch.ops.aten.mul.Tensor(squeeze_2, 1.0140845070422535);  squeeze_2 = None
    mul_4 = torch.ops.aten.mul.Tensor(mul_3, 0.1);  mul_3 = None
    mul_5 = torch.ops.aten.mul.Tensor(clone_1, 0.9);  clone_1 = None
    add_3 = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
    unsqueeze = torch.ops.aten.unsqueeze.default(primals_3, -1)
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(primals_4, -1);  primals_4 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = None
    mul_6 = torch.ops.aten.mul.Tensor(mul, unsqueeze_1);  mul = unsqueeze_1 = None
    add_4 = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3);  mul_6 = unsqueeze_3 = None
    relu = torch.ops.aten.relu.default(add_4);  add_4 = None
    le = torch.ops.aten.le.Scalar(relu, 0)
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(squeeze, 0);  squeeze = None
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2);  unsqueeze_4 = None
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(unsqueeze_5, 3);  unsqueeze_5 = None
    return [add_2, add_3, add, relu, primals_1, primals_3, primals_8, convolution, squeeze_1, le, unsqueeze_6]

The resulting loop-level IR is then converted to Triton by compile_to_fn() on the following line, and the generated Triton code is stored in a .py file under the /tmp/ directory. The return value is a function compiled_fn, whose __module__ attribute records the path of that file. The sketch below shows one way to exercise this path directly.
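To exercise this path end to end, one can call compile_fx directly on an aten-level FX graph; a sketch assuming the compile_fx(gm, example_inputs) entry point and the __module__ behavior described above:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.compile_fx import compile_fx

def f(x):
    return (torch.relu(x) + 1,)

example = [torch.randn(8, device="cuda")]
gm = make_fx(f)(*example)              # an aten-level FX GraphModule
compiled_fn = compile_fx(gm, example)  # lowering + scheduling + codegen
print(compiled_fn.__module__)          # module name of the generated /tmp .py file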

7 GraphLowering --> Triton kernel
  • Call graph.compile_to_fn()
  • This function first calls the graph's compile_to_module(), takes the call attribute of the returned value, and returns it
  • compile_to_module(), in turn, first calls self.codegen() to generate the Triton code (yielding a .py file), then renames this code and returns it
  • Inside codegen, self.init_wrapper_code() is called first; this function merely checks whether a C++ wrapper is needed (usually it is not), so it instantiates a WrapperCodeGen() object and returns
  • The graph's scheduler is then instantiated; the objects being scheduled are those constructed in the loop-level IR, which can effectively be regarded as compute nodes
    • The instantiation process (paraphrased in code after this list):
      • Declare an empty node list for the new construction
      • Collect the names of the buffers that subsequent computation depends on
      • Iterate over the list passed in (the buffers from earlier), and for each element:
        • Check whether this node has incoming edges (i.e. where its data comes from, usually buffer names)
        • Check the node's type. Since the nodes passed in here are all buffers, the is_no_op branch is not entered. Next, check whether the node is a ComputedBuffer or a TemplateBuffer, where TemplateBuffer is documented as "Represents a Triton (in the future other types) of template operator that we can fuse an epilogue onto." (Every subsequent ComputedBuffer takes this branch and executes self.get_backend(node.get_device()).group_fn.) A convolution, by contrast, is treated here as an ExternKernel.
        • Append the node just created to the node list declared at the start
        • After this, perform dead-node elimination and node fusion
      • Once scheduling is complete, kernel generation begins immediately. Note that special operators (such as convolution) are not translated to Triton; they are emitted directly as aten calls. Otherwise, we enter the codegen_kernel stage.
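The instantiation walkthrough above, paraphrased in runnable toy form (the real Scheduler lives in torch/_inductor/scheduler.py; Buf and the string tags here are hypothetical stand-ins):

from dataclasses import dataclass, field

@dataclass
class Buf:
    name: str
    kind: str                 # "computed", "template", or "extern"
    reads: list = field(default_factory=list)

def build_scheduler_nodes(buffers):
    nodes = []                                        # 1. empty node list
    for buf in buffers:                               # 2. walk the lowered buffers
        deps = list(buf.reads)                        # 3. incoming edges (buffer names)
        if buf.kind == "extern":                      # e.g. convolution: stays an aten call
            nodes.append(("extern", buf.name, deps))
        elif buf.kind in ("computed", "template"):    # ComputedBuffer / TemplateBuffer
            nodes.append(("kernel", buf.name, deps))  # will go through codegen_kernel
    # 4. dead-node elimination and node fusion would run over `nodes` here
    return nodes

print(build_scheduler_nodes([
    Buf("buf0", "extern"),
    Buf("buf1", "computed", ["buf0"]),
]))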