DGL-kernel的变更(3)

本文详细梳理了DGL库中图计算内核从minigun到featgraph的演变过程,包括在v0.3-0.4版本中的minigun,v0.5版本中基于SPMM和SDDMM接口的重写,以及v0.6版本中引入的基于tvm的featgraph。通过对gspmm和gsddmm源码的分析,展示了TVM中实现GSPMM和GSDDMM计算函数和调度的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

导读: DGL kernel 中针对图(Graph)的计算在几个版本之间有了不小的变动。
  v0.3-0.4使用的是minigun, v0.3和v0.4源码中主要相关部分则是在对应分支dgl/src/kernel目录下。
  v0.5 中对 kernel 代码进行了重写,不再继续使用 minigun,可以参见 PR #1644。此时 DGL 将 GNN 主要抽象为了 SPMM 和 SDDMM 两个接口,这里也可以参考 DGL v0.5 的论文。
  后来 AWS DGL 团队又提出了一个新的基于 tvm 的 kernel——FeatGraph,并发表在了 SC20 上;DGL v0.6 中已经部分合入了该实现,位于目录 dgl/featgraph/ 下。
  本文想简单梳理一下DGL各kernel的改动。
  ps: v0.5版本后的代码逻辑感觉清晰了很多,更易读一些(完全不会C++的我都能把逻辑流程大概走通)。

yzh119-fg-kernel分支

  总体来看,featgraph 的思路和之前手写并行计算的方式完全不同,但 C++ 部分的代码似乎没有完全开源?还是说已经嵌入到了 tvm 中?

  gspmm.py源码如下:

""" The compute function and schedules for GSpMM kernels written in TVM. """
import tvm
from tvm import te
from tvm import topi
from tvm.tir import IntImm, decl_buffer
from tvm.topi.utils import prod
from .utils import binary_op_map, reduce_op_map, TargetCode

# Only the compile entry point is exported; the helpers are private.
__all__ = ["gspmm"]


def _spmm(out_shp, binary_op, reduce_op,
          indptr, indices,
          ufeat, efeat):
    """Define the GSpMM TVM compute function over a CSR graph.

    For every output row ``row``, the edges ``eid`` in
    ``[indptr[row], indptr[row + 1])`` are visited. For each edge a message
    ``binary_op_map[binary_op](u_val, e_val)`` is built, where
    ``u_val = ufeat[indices[eid], ...]`` and ``e_val = efeat[eid, ...]``.
    The messages are then combined with ``reduce_op_map[reduce_op]``.

    Parameters
    ----------
    out_shp : list
        Output shape. The first axis ranges over the rows described by
        ``indptr``; the remaining axes are feature axes.
    binary_op : str
        Key into ``binary_op_map``. ``copy_lhs`` skips reading ``efeat`` and
        ``copy_rhs`` skips reading ``ufeat``. The unread operand is passed
        to the message function as ``None``.
    reduce_op : str
        Key into ``reduce_op_map``. For ``min``/``max``, the edge id (when
        ``efeat`` is read) and/or the source id (when ``ufeat`` is read) are
        fed to the reducer ahead of the message. This is presumably so the
        reducer can also emit argument indices.
    indptr, indices : te.Tensor
        CSR row-pointer and column-index placeholders.
    ufeat, efeat : te.Tensor
        Source-node and edge feature placeholders.

    Returns
    -------
    tuple or list of te.Tensor
        For ``sum``, the 1-tuple ``(out,)``. Otherwise, whatever
        ``te.compute`` returns for the multi-value reducer. Callers use the
        last element as the message result.
    """
    use_u = binary_op != 'copy_rhs'
    use_e = binary_op != 'copy_lhs'
    require_arg = reduce_op in ('min', 'max')

    def node_func(*args):
        row = args[0]
        start = indptr[row]
        end = indptr[row + 1]
        # NOTE: a guard `length = te.if_then_else(row < n, end - start, 0)`
        # (with `n = indptr.shape[0] - 1`) was meant to avoid a possible
        # illegal memory access (https://github.com/apache/tvm/issues/6596).
        # It is currently disabled; the plain edge count is used.
        k = te.reduce_axis((0, end - start), name='k')
        eid = k + start
        uid = indices[eid]
        feat_idx = args[1:]
        msg = []
        e_val, u_val = None, None
        if use_e:
            e_val = efeat[(eid,) + feat_idx]
            if require_arg:
                msg.append(eid)
        if use_u:
            u_val = ufeat[(uid,) + feat_idx]
            if require_arg:
                msg.append(uid)
        # The message value goes last so callers can take `ret[-1]`.
        msg.append(binary_op_map[binary_op](u_val, e_val))
        return reduce_op_map[reduce_op](msg, axis=k)

    ret = te.compute(out_shp, node_func, name='out')
    if reduce_op == 'sum':
        # Single-output compute: wrap so callers can always index/extend.
        ret = (ret,)
    return ret


def _spmm_cuda_general(sched, out):
    """General CUDA schedule for GSpMM.

    The row axis is split with factor 1, and each outer row goes to
    ``blockIdx.x``. The unit inner part is bound to ``threadIdx.y``. All
    feature axes are fused into one axis bound to ``threadIdx.x``.
    """
    stage = sched[out]
    row_axis, feat_axes = out.op.axis[0], out.op.axis[1:]
    fused_feat = stage.fuse(*feat_axes)
    block_rows, rows_in_block = stage.split(row_axis, factor=1)
    bindings = (
        (block_rows, 'blockIdx.x'),
        (rows_in_block, 'threadIdx.y'),
        (fused_feat, 'threadIdx.x'),
    )
    for axis, tag in bindings:
        stage.bind(axis, te.thread_axis(tag))


# CUDA schedule builders for GSpMM, keyed by ``schedule_type``. Each one is
# called as ``fn(sched, out)`` and mutates ``sched`` in place.
_gspmm_cuda_schedule = dict(
    general=_spmm_cuda_general,
)


def gspmm(binary_op, reduce_op,
          ndim,
          indice_type, feat_type,
          schedule_type='general',
          target='cuda'):
    """
    Compile GSPMM kernels using TVM.

    The graph is described in CSR form by ``indptr``/``indices``. Each of the
    ``num_rows`` output rows reduces the messages of its entries.

    Parameters
    ----------
    binary_op : str
        Message operator. ``copy_lhs`` reads only source-node features
        (``ufeat``); ``copy_rhs`` reads only edge features (``efeat``).
    reduce_op : str
        Reduction operator, e.g. ``sum``, ``min`` or ``max``.
    ndim : int
        Number of feature axes after the leading row/edge axis.
    indice_type : str
        Data type of graph indices and symbolic shape variables.
    feat_type : str
        Data type of features.
    schedule_type : str
        Key into the CUDA schedule table; only ``general`` is defined.
    target : str
        Compilation target; only ``cuda`` is supported.

    Returns
    -------
    The result of ``tvm.lower`` for the scheduled kernel. Its arguments are
    ``indptr``, ``indices``, then ``ufeat`` and/or ``efeat`` (only the
    operands ``binary_op`` reads), then every output of the compute.

    Raises
    ------
    NotImplementedError
        If ``target`` is not ``'cuda'``.
    """
    idx_dtype = indice_type
    n_rows = te.var('num_rows', idx_dtype)
    n_cols = te.var('num_cols', idx_dtype)
    n_edges = te.var('nnz', idx_dtype)

    indptr = te.placeholder((n_rows + 1,), idx_dtype, 'indptr')
    indices = te.placeholder((n_edges,), idx_dtype, 'indices')
    out_dims = [te.var('do_{}'.format(d), idx_dtype) for d in range(ndim)]
    u_dims = [te.var('du_{}'.format(d), idx_dtype) for d in range(ndim)]
    e_dims = [te.var('dv_{}'.format(d), idx_dtype) for d in range(ndim)]

    ufeat = te.placeholder([n_cols] + u_dims, feat_type, 'ufeat')
    efeat = te.placeholder([n_edges] + e_dims, feat_type, 'efeat')

    outputs = _spmm([n_rows] + out_dims, binary_op, reduce_op,
                    indptr, indices, ufeat, efeat)
    # The last output is the message result; any outputs before it come from
    # the min/max reducers (presumably argument indices).
    main_out = outputs[-1]
    sched = te.create_schedule(main_out.op)

    if target != 'cuda':
        raise NotImplementedError("CPU schedule not implemented yet.")
    _gspmm_cuda_schedule[schedule_type](sched, main_out)

    # Note: ndim is not included in the kernel name.
    kernel_name = '_'.join(map(str, (
        'spmm', binary_op, reduce_op,
        indice_type, feat_type,
        schedule_type, target,
    )))

    kernel_args = [indptr, indices]
    if binary_op != 'copy_rhs':  # message reads source-node features
        kernel_args.append(ufeat)
    if binary_op != 'copy_lhs':  # message reads edge features
        kernel_args.append(efeat)
    kernel_args.extend(outputs)

    buffer_binds = {
        ufeat: decl_buffer(ufeat.shape, ufeat.dtype, name='u_buf',
                           buffer_type='auto_broadcast'),
        efeat: decl_buffer(efeat.shape, efeat.dtype, name='e_buf',
                           buffer_type='auto_broadcast'),
    }
    return tvm.lower(sched, kernel_args, name=kernel_name, binds=buffer_binds)


  gsddmm.py源码如下:

""" The compute function and schedules for GSDDMM kernels written in TVM. """
import tvm
from tvm import te
from tvm import topi
from tvm.tir import IntImm, decl_buffer
from tvm.topi.utils import prod
from .utils import binary_op_map, TargetCode

# Only the compile entry point is exported; the helpers are private.
__all__ = ["gsddmm"]


def _sddmm_compute(out_shp, binary_op,
                   lhs, rhs, get_lhs, get_rhs):
    """Define the SDDMM's TVM compute function.

    For every edge ``eid`` (the first output axis), ``lhs[get_lhs(eid), ...]``
    is combined with ``rhs[get_rhs(eid), ...]``.

    Parameters
    ----------
    out_shp : list
        Output shape; the first axis ranges over edges.
    binary_op : str
        ``dot`` sums ``lval * rval`` over the last feature axis of ``lhs``.
        Any other value is looked up in ``binary_op_map`` and applied
        elementwise.
    lhs, rhs : te.Tensor
        Operand placeholders.
    get_lhs, get_rhs : callable
        Map an edge id expression to the leading index into ``lhs``/``rhs``.

    Returns
    -------
    te.Tensor
        The output tensor named ``out``.
    """
    if binary_op == 'dot':
        # Reduction over the last feature axis of `lhs`. The last output
        # index is not used for addressing (`args[1:-1]`); presumably that
        # output axis has extent 1 -- confirm against callers.
        k = te.reduce_axis((0, lhs.shape[-1]), name='k')

        def dot_edge_func(*args):
            eid = args[0]
            feat_idx = args[1:-1]
            lval = lhs[(get_lhs(eid),) + feat_idx + (k,)]
            rval = rhs[(get_rhs(eid),) + feat_idx + (k,)]
            return te.sum(lval * rval, axis=k)

        return te.compute(out_shp, dot_edge_func, name='out')

    def edge_func(*args):
        eid = args[0]
        feat_idx = args[1:]
        lval = lhs[(get_lhs(eid),) + feat_idx]
        rval = rhs[(get_rhs(eid),) + feat_idx]
        return binary_op_map[binary_op](lval, rval)

    return te.compute(out_shp, edge_func, name='out')


def _sddmm_cuda_general(sched, out):
    """Define the SDDMM's TVM general schedule.

    All feature axes of ``out`` are fused into one, and both the fused
    feature axis and the edge axis are tiled by 32:

    * feature tile -> ``threadIdx.x``, feature tiles -> ``blockIdx.y``
    * edge tile    -> ``threadIdx.y``, edge tiles    -> ``blockIdx.x``

    Each block therefore has 32 x 32 = 1024 threads.
    """
    # TODO: derive the tile sizes from the fused feature length (e.g. the
    # largest power of two not exceeding it, with ntx * nty capped at 1024)
    # instead of fixing both to 32.
    ntx = 32
    nty = 32
    edge_axis = out.op.axis[0]
    feat_axis = sched[out].fuse(*out.op.axis[1:])
    feat_outer, feat_inner = sched[out].split(feat_axis, factor=ntx)
    edge_outer, edge_inner = sched[out].split(edge_axis, factor=nty)
    sched[out].bind(feat_inner, te.thread_axis('threadIdx.x'))
    sched[out].bind(feat_outer, te.thread_axis('blockIdx.y'))
    sched[out].bind(edge_inner, te.thread_axis('threadIdx.y'))
    sched[out].bind(edge_outer, te.thread_axis('blockIdx.x'))


def _sddmm_cuda_tree_reduce(sched, out):
    """Define the SDDMM's TVM tree-reduction schedule.

    This schedule requires ``out`` to have a reduction axis, which the
    compute function only creates for ``binary_op == 'dot'``. The inner
    32-wide chunk of that reduction is bound to ``threadIdx.x``, so TVM
    lowers it as a cross-thread reduction. Edges are tiled by 32 onto
    ``threadIdx.y`` / ``blockIdx.x``.
    """
    # TODO(zihao): handle other dimensions.
    stage = sched[out]
    edge_ax = out.op.axis[0]
    red_ax = out.op.reduce_axis[0]
    # Only the inner part of the reduction is thread-bound; the outer part
    # stays a serial loop.
    red_inner = stage.split(red_ax, factor=32)[1]
    edge_outer, edge_inner = stage.split(edge_ax, factor=32)
    for axis, tag in ((red_inner, 'threadIdx.x'),
                      (edge_inner, 'threadIdx.y'),
                      (edge_outer, 'blockIdx.x')):
        stage.bind(axis, te.thread_axis(tag))


# CUDA schedule builders for GSDDMM, keyed by ``schedule_type``. Each one is
# called as ``fn(sched, out)`` and mutates ``sched`` in place.
_sddmm_cuda_schedule = dict(
    general=_sddmm_cuda_general,
    tree=_sddmm_cuda_tree_reduce,
)


def gsddmm(binary_op,
           ndim,
           indice_type, feat_type,
           lhs_target=TargetCode.SRC, rhs_target=TargetCode.DST,
           schedule_type='tree',
           target='cuda'):
    """
    Compile GSDDMM kernel using TVM.

    The graph is given in COO form (``row``/``col``, one entry per edge).
    One output value is produced per edge.

    Parameters
    ----------
    binary_op : str
        Type of binary operation, could be ``add``, ``sub``, ``mul``,
        ``div`` or ``dot``.
    ndim : int
        Dimensionality: the number of feature axes after the leading
        node/edge axis of every tensor.
    indice_type : str
        Type of graph indices, could be ``int32`` or ``int64``.
    feat_type : str
        Type of features, could be ``float16``/``float32``/``float64``
        or ``int32``/``int64``.
    lhs_target : TargetCode
        Indicates the left-hand-side tensor's target (``SRC``, ``EDGE`` or
        ``DST``). It decides the tensor's leading dimension and how an edge
        id is mapped to its row (``row[eid]``, ``eid`` or ``col[eid]``).
    rhs_target : TargetCode
        Indicates the right-hand-side tensor's target, as for ``lhs_target``.
    schedule_type : str
        Specifies the schedule type, could be either ``tree`` or ``general``.
        ``tree`` schedules the reduction axis, which only exists when
        ``binary_op == 'dot'``; use ``general`` for the other operators.
    target : str
        Indicates where kernels are run. Only ``cuda`` gets a schedule;
        ``llvm`` raises ``NotImplementedError``. Other targets are lowered
        without applying a schedule.

    Returns
    -------
    IRModule, representing the lowered kernel.

    Raises
    ------
    ValueError
        If ``lhs_target`` or ``rhs_target`` is not a known ``TargetCode``.
    NotImplementedError
        If ``target == 'llvm'``.
    """
    num_rows = te.var('num_rows', indice_type)
    num_cols = te.var('num_cols', indice_type)
    nnz = te.var('nnz', indice_type)

    # placeholder for sparse matrix (COO: one (row, col) pair per edge)
    row = te.placeholder((nnz,), indice_type, 'row')
    col = te.placeholder((nnz,), indice_type, 'col')

    # placeholder for dense features; the leading dimension depends on
    # whether the operand lives on source nodes, edges or destination nodes
    def create_placeholder(code, feat_shp, name):
        if code == TargetCode.SRC:
            return te.placeholder([num_rows] + feat_shp, feat_type, name)
        elif code == TargetCode.EDGE:
            return te.placeholder([nnz] + feat_shp, feat_type, name)
        elif code == TargetCode.DST:
            return te.placeholder([num_cols] + feat_shp, feat_type, name)
        # DGLError is not imported in this module (raising it would surface
        # as a NameError), so use a built-in exception.
        raise ValueError('Unknown target: {}'.format(code))

    out_feat_shp = [te.var('d_o{}'.format(i), indice_type) for i in range(ndim)]
    lhs_feat_shp = [te.var('d_l{}'.format(i), indice_type) for i in range(ndim)]
    rhs_feat_shp = [te.var('d_r{}'.format(i), indice_type) for i in range(ndim)]
    lhs = create_placeholder(lhs_target, lhs_feat_shp, 'lhs')
    rhs = create_placeholder(rhs_target, rhs_feat_shp, 'rhs')

    # idx wrapper for corresponding target
    target_getter = {
        TargetCode.SRC: lambda eid: row[eid],
        TargetCode.EDGE: lambda eid: eid,
        TargetCode.DST: lambda eid: col[eid]
    }

    # compute
    out = _sddmm_compute([nnz] + out_feat_shp,
                         binary_op, lhs, rhs,
                         target_getter[lhs_target], target_getter[rhs_target])

    # schedule
    sched = te.create_schedule(out.op)

    if target == 'cuda':
        _sddmm_cuda_schedule[schedule_type](sched, out)
    elif target == 'llvm':
        raise NotImplementedError('CPU schedule not implemented yet.')

    # prepare input: graph indices, operands, then the output
    f_name = '_'.join(str(x) for x in [
        'sddmm', binary_op, ndim,
        indice_type, feat_type,
        lhs_target, rhs_target, schedule_type, target])
    f_input = [row, col, lhs, rhs, out]

    # bind autobroadcast buffer
    lhs_buffer = decl_buffer(lhs.shape, lhs.dtype, name='lhs_buf',
                             buffer_type='auto_broadcast')
    rhs_buffer = decl_buffer(rhs.shape, rhs.dtype, name='rhs_buf',
                             buffer_type='auto_broadcast')
    binds = {lhs: lhs_buffer, rhs: rhs_buffer}
    return tvm.lower(sched, f_input, name=f_name, binds=binds)


  而且在 src/array/kernel.cc 中还新注册了基于FeatGraph的算子,但目前应该还没有用上:

// fg-kernel branch, src/array/kernel.cc line 251
// FeatGraph GSDDMM entry point. It validates lhs/rhs/out and forwards the
// graph's COO row/col arrays plus the three tensors (as DLPack) to
// dgl::featgraph::SDDMM.
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: graph, lhs, rhs, out. Nothing is written to rv.
    HeteroGraphRef graph = args[0];
    NDArray lhs = args[1];
    NDArray rhs = args[2];
    NDArray out = args[3];
    // Check the tensors against the graph's context and for contiguity
    // (helpers defined elsewhere in this file).
    CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
    CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
    // Only graphs with exactly one edge type are accepted.
    CHECK_EQ(graph->NumEdgeTypes(), 1);
    // TODO: shape validation is disabled. The draft below needs
    // lhs_target/rhs_target, which this entry point does not receive.
    // auto pair = graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
    // const dgl_type_t src_vtype = pair.first;
    // const dgl_type_t dst_vtype = pair.second;
    // CheckShape(
    //     {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
    //     {lhs_target, rhs_target, 1},
    //     {lhs, rhs, out},
    //     {"U_data", "E_data", "V_data"});
    // Edge-wise kernel: uses the COO view (one row/col pair per edge).
    COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
    dgl::featgraph::SDDMM(coo.row.ToDLPack(), coo.col.ToDLPack(),
                          lhs.ToDLPack(), rhs.ToDLPack(), out.ToDLPack());
  });

// FeatGraph GSpMM entry point. It validates the feature, output and
// argument-index tensors, then forwards the graph's CSC indptr/indices plus
// the tensors (as DLPack) to dgl::featgraph::SPMM.
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SPMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: graph, ufeat, efeat, out, argu, arge. Nothing is written to rv.
    HeteroGraphRef graph = args[0];
    NDArray ufeat = args[1];
    NDArray efeat = args[2];
    NDArray out = args[3];
    NDArray argu = args[4];
    NDArray arge = args[5];
    // argu/arge are handed to the kernel alongside out, so they get the same
    // context/contiguity checks as the other tensors.
    CheckCtx(graph->Context(), {ufeat, efeat, out, argu, arge},
             {"ufeat", "efeat", "out", "argu", "arge"});
    CheckContiguous({ufeat, efeat, out, argu, arge},
                    {"ufeat", "efeat", "out", "argu", "arge"});
    // Only graphs with exactly one edge type are accepted.
    CHECK_EQ(graph->NumEdgeTypes(), 1);
    // Row-wise aggregation over incoming edges: use the CSC view.
    CSRMatrix csc = graph.sptr()->GetCSCMatrix(0);
    dgl::featgraph::SPMM(csc.indptr.ToDLPack(), csc.indices.ToDLPack(),
                         ufeat.ToDLPack(), efeat.ToDLPack(),
                         out.ToDLPack(), argu.ToDLPack(), arge.ToDLPack());
  });
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值