Overview: DGL's kernels for computation on graphs have changed quite a bit across versions.
v0.3-v0.4 used minigun; in the v0.3 and v0.4 source trees, the relevant code lives under dgl/src/kernel on the corresponding branches.
v0.5 rewrote the kernel code and dropped minigun; see PR 1644. From this point on, DGL abstracts GNN computation mainly into two interfaces, SpMM and SDDMM (sketched below; see also the DGL v0.5 paper).
Later, the AWS DGL team proposed FeatGraph, a new TVM-based kernel published at SC20; it has been partially merged into DGL v0.6, under the directory dgl/featgraph/.
This post is a quick walkthrough of how DGL's kernels have changed.
P.S. The code after v0.5 feels much cleaner and easier to read (even I, knowing no C++ at all, could roughly trace the logic end to end).
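To make the two primitives concrete before diving in: g-SpMM reduces messages over each node's incoming edges, while g-SDDMM computes one value per edge from features on its endpoints. A minimal NumPy sketch of the semantics (my own illustration, not DGL code):

import numpy as np

# A toy graph as an edge list: edge e goes src[e] -> dst[e].
src = np.array([0, 0, 1, 2])
dst = np.array([1, 2, 2, 0])
ufeat = np.random.rand(3, 4)  # node features
efeat = np.random.rand(4, 4)  # edge features

# g-SpMM: out[v] = reduce over incoming edges (u, e, v) of ufeat[u] * efeat[e]
spmm_out = np.zeros((3, 4))
for e, (u, v) in enumerate(zip(src, dst)):
    spmm_out[v] += ufeat[u] * efeat[e]

# g-SDDMM: out[e] = op(feat[src[e]], feat[dst[e]]), e.g. a per-edge dot product
sddmm_out = np.einsum('ed,ed->e', ufeat[src], ufeat[dst])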
The yzh119-fg-kernel branch
Overall, the FeatGraph approach looks quite different from the earlier hand-written parallel kernels. That said, the C++ side does not seem to be fully open-sourced; or is it embedded in TVM?
gspmm.py
The source is as follows, with a few notes along the way:
""" The compute function and schedules for GSpMM kernels written in TVM. """
import tvm
from tvm import te
from tvm import topi
from tvm.tir import IntImm, decl_buffer
from tvm.topi.utils import prod
from .utils import binary_op_map, reduce_op_map, TargetCode
__all__ = ['gspmm']
def _spmm(out_shp, binary_op, reduce_op,
indptr, indices,
ufeat, efeat):
use_u = binary_op != 'copy_rhs'
use_e = binary_op != 'copy_lhs'
require_arg = reduce_op in ['min', 'max']
n = indptr.shape[0] - 1
def node_func(*args):
row = args[0]
start = indptr[row]
end = indptr[row + 1]
# The following line avoids the possible illegal memory access error
# Related issue: https://github.com/apache/tvm/issues/6596
#length = te.if_then_else(row < n, end - start, 0)
length = end - start
k = te.reduce_axis((0, length), name='k')
eid = k + start
uid = indices[eid]
msg = []
e_val, u_val = None, None
if use_e:
e_val = efeat.__getitem__((eid,) + args[1:])
if require_arg:
msg.append(eid)
if use_u:
u_val = ufeat.__getitem__((uid,) + args[1:])
if require_arg:
msg.append(uid)
msg.append(binary_op_map[binary_op](u_val, e_val))
return reduce_op_map[reduce_op](msg, axis=k)
ret = te.compute(out_shp, node_func, name='out')
if reduce_op == 'sum':
ret = (ret,)
return ret
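A subtlety in node_func: when reduce_op is min or max, msg also carries eid and/or uid, so the multi-output reducer in reduce_op_map can return the arg-edge/arg-node indices alongside the reduced value; out = ret[-1] in gspmm below then picks out the value tensor (the sum case is wrapped in a tuple only to keep the return type uniform). What this computes, as a minimal NumPy sketch for binary_op='mul', reduce_op='min' with 2-D float features (my own illustration, not code from the branch):

import numpy as np

def spmm_min_with_arg(indptr, indices, ufeat, efeat):
    # NumPy rendition of _spmm with binary_op='mul', reduce_op='min'.
    n = len(indptr) - 1
    d = ufeat.shape[1]
    out = np.full((n, d), np.inf, ufeat.dtype)
    arg_e = np.zeros((n, d), np.int64)  # eid attaining the min (msg[0])
    arg_u = np.zeros((n, d), np.int64)  # uid attaining the min (msg[1])
    for row in range(n):
        for eid in range(indptr[row], indptr[row + 1]):
            uid = indices[eid]
            val = ufeat[uid] * efeat[eid]
            better = val < out[row]
            arg_e[row] = np.where(better, eid, arg_e[row])
            arg_u[row] = np.where(better, uid, arg_u[row])
            out[row] = np.where(better, val, out[row])
    return arg_e, arg_u, out  # mirrors the msg order: [eid, uid, value]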
def _spmm_cuda_general(sched, out):
    node_axis = out.op.axis[0]
    feat_axis = sched[out].fuse(*out.op.axis[1:])
    node_outer, node_inner = sched[out].split(node_axis, factor=1)
    sched[out].bind(node_outer, te.thread_axis('blockIdx.x'))
    sched[out].bind(node_inner, te.thread_axis('threadIdx.y'))
    sched[out].bind(feat_axis, te.thread_axis('threadIdx.x'))


_gspmm_cuda_schedule = {
    'general': _spmm_cuda_general,
}
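The general schedule is as simple as it gets: every output node becomes its own CUDA block (the factor-1 split leaves threadIdx.y with a single element), and the fused feature axis is bound to threadIdx.x. Presumably this assumes the fused feature length fits within a thread block; there is no feature-axis split here yet, unlike the SDDMM schedules below.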
def gspmm(binary_op, reduce_op,
          ndim,
          indice_type, feat_type,
          schedule_type='general',
          target='cuda'):
    """
    Compile GSPMM kernels using TVM.
    """
    num_rows = te.var('num_rows', indice_type)
    num_cols = te.var('num_cols', indice_type)
    nnz = te.var('nnz', indice_type)
    indptr = te.placeholder((num_rows + 1,), indice_type, 'indptr')
    indices = te.placeholder((nnz,), indice_type, 'indices')
    out_feat_shp = [te.var('do_{}'.format(i), indice_type) for i in range(ndim)]
    ufeat_shp = [te.var('du_{}'.format(i), indice_type) for i in range(ndim)]
    efeat_shp = [te.var('dv_{}'.format(i), indice_type) for i in range(ndim)]
    ufeat = te.placeholder([num_cols] + ufeat_shp, feat_type, 'ufeat')
    efeat = te.placeholder([nnz] + efeat_shp, feat_type, 'efeat')
    ret = _spmm([num_rows] + out_feat_shp,
                binary_op, reduce_op,
                indptr, indices, ufeat, efeat)
    out = ret[-1]
    sched = te.create_schedule(out.op)
    if target == 'cuda':
        _gspmm_cuda_schedule[schedule_type](sched, out)
    else:
        raise NotImplementedError("CPU schedule not implemented yet.")
    f_input = [indptr, indices]
    f_name = '_'.join(str(x) for x in [
        'spmm', binary_op, reduce_op,  # ndim,
        indice_type, feat_type,
        schedule_type, target
    ])
    use_u = binary_op != 'copy_rhs'
    use_e = binary_op != 'copy_lhs'
    if use_u:
        f_input.append(ufeat)
    if use_e:
        f_input.append(efeat)
    f_input += ret
    u_buffer = decl_buffer(ufeat.shape, ufeat.dtype, name='u_buf',
                           buffer_type='auto_broadcast')
    e_buffer = decl_buffer(efeat.shape, efeat.dtype, name='e_buf',
                           buffer_type='auto_broadcast')
    binds = {ufeat: u_buffer, efeat: e_buffer}
    return tvm.lower(sched, f_input, name=f_name, binds=binds)
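Note that every shape here is symbolic (te.var), so one compiled kernel serves graphs of any size, and the auto_broadcast buffers let ufeat/efeat broadcast along feature dimensions. A hypothetical usage sketch (mine; it assumes a TVM version of the branch's era, where tvm.lower returns an IRModule that tvm.build accepts, and it pretends the module is importable as plain gspmm even though the file actually uses a relative .utils import):

import tvm
from gspmm import gspmm  # actual import path depends on the branch layout

ir_mod = gspmm('mul', 'sum', ndim=1,
               indice_type='int32', feat_type='float32')
print(ir_mod)                          # inspect the lowered TIR
fn = tvm.build(ir_mod, target='cuda')  # compile to a CUDA kernel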
gsddmm.py
Again the source, with notes along the way:
""" The compute function and schedules for GSDDMM kernels written in TVM. """
import tvm
from tvm import te
from tvm import topi
from tvm.tir import IntImm, decl_buffer
from tvm.topi.utils import prod
from .utils import binary_op_map, TargetCode
__all__ = ['gsddmm']
def _sddmm_compute(out_shp, binary_op,
lhs, rhs, get_lhs, get_rhs):
"""Define the SDDMM' TVM compute function. """
reduce_size = lhs.shape[-1] if binary_op == 'dot' else 1
feat_len = prod(out_shp[1:])
feat_len *= reduce_size
if binary_op == 'dot':
k = te.reduce_axis((0, reduce_size), name='k')
def dot_edge_func(*args):
eid = args[0]
lval = lhs.__getitem__((get_lhs(eid),) + args[1: -1] + (k,))
rval = rhs.__getitem__((get_rhs(eid),) + args[1: -1] + (k,))
return te.sum(lval * rval, axis=k)
out = te.compute(out_shp, dot_edge_func, name='out')
else:
def edge_func(*args):
eid = args[0]
lval = lhs.__getitem__((get_lhs(eid),) + args[1:])
rval = rhs.__getitem__((get_rhs(eid),) + args[1:])
return binary_op_map[binary_op](lval, rval)
out = te.compute(out_shp, edge_func, name='out')
return out
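get_lhs/get_rhs abstract away where each operand lives: for source/destination node features they dereference row[eid]/col[eid], while edge features are indexed by eid directly (see target_getter in gsddmm below). For dot, the trailing feature axis is reduced with te.sum; every other binary_op is a plain elementwise op per edge.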
def _sddmm_cuda_general(sched, out):
    """Define the SDDMM's TVM general schedule. """
    out_len = prod(out.shape[1:])
    edge_axis = out.op.axis[0]
    feat_axis = sched[out].fuse(*out.op.axis[1:])
    # ntx = tvm.autotvm.task.space.get_pow2s(out_len)[-1]
    # ntx = 1024 if ntx > 1024 else ntx
    # nty = 1024 // ntx
    ntx = 32
    nty = 32
    feat_outer, feat_inner = sched[out].split(feat_axis, factor=ntx)
    edge_outer, edge_inner = sched[out].split(edge_axis, factor=nty)
    sched[out].bind(feat_inner, te.thread_axis('threadIdx.x'))
    sched[out].bind(feat_outer, te.thread_axis('blockIdx.y'))
    sched[out].bind(edge_inner, te.thread_axis('threadIdx.y'))
    sched[out].bind(edge_outer, te.thread_axis('blockIdx.x'))


def _sddmm_cuda_tree_reduce(sched, out):
    """Define the SDDMM's TVM tree-reduction schedule. """
    # TODO(zihao): handle other dimensions.
    edge_axis = out.op.axis[0]
    reduce_axis = out.op.reduce_axis[0]
    # sched[out].bind(reduce_axis, te.thread_axis('threadIdx.x'))
    # sched[out].bind(edge_axis, te.thread_axis('blockIdx.x'))
    _, red_inner = sched[out].split(reduce_axis, factor=32)
    edge_outer, edge_inner = sched[out].split(edge_axis, factor=32)
    sched[out].bind(red_inner, te.thread_axis('threadIdx.x'))
    sched[out].bind(edge_inner, te.thread_axis('threadIdx.y'))
    sched[out].bind(edge_outer, te.thread_axis('blockIdx.x'))


_sddmm_cuda_schedule = {
    'general': _sddmm_cuda_general,
    'tree': _sddmm_cuda_tree_reduce
}
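The general schedule tiles edges and features into 32x32 thread blocks. The tree-reduction schedule targets dot: splitting the reduce axis by 32 and binding the inner part to threadIdx.x makes TVM emit a cross-thread (warp-level) tree reduction for each edge's dot product, with 32 edges per block along threadIdx.y. Only the edge axis and the reduce axis are scheduled, hence the TODO about handling other dimensions.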
def gsddmm(binary_op,
           ndim,
           indice_type, feat_type,
           lhs_target=TargetCode.SRC, rhs_target=TargetCode.DST,
           schedule_type='tree',
           target='cuda'):
    """
    Compile GSDDMM kernel using TVM.

    Parameters
    ----------
    binary_op : str
        Type of binary operation, could be ``add``, ``sub``, ``mul``,
        ``div`` or ``dot``.
    ndim : int
        Dimensionality.
    indice_type : str
        Type of graph indices, could be ``int32`` or ``int64``.
    feat_type : str
        Type of features, could be ``float16``/``float32``/``float64``
        or ``int32``/``int64``.
    lhs_target : TargetCode
        Indicates the left-hand-side tensor's target.
    rhs_target : TargetCode
        Indicates the right-hand-side tensor's target.
    schedule_type : str
        Specifies the schedule type, could be either ``tree`` or ``general``.
    target : str
        Indicates where kernels are run, i.e. CPU or GPU.

    Returns
    -------
    IRModule, representing the compiled kernel.
    """
    num_rows = te.var('num_rows', indice_type)
    num_cols = te.var('num_cols', indice_type)
    nnz = te.var('nnz', indice_type)
    # placeholders for the sparse matrix (COO format)
    row = te.placeholder((nnz,), indice_type, 'row')
    col = te.placeholder((nnz,), indice_type, 'col')

    # placeholders for dense features
    def create_placeholder(target, feat_shp, name):
        if target == TargetCode.SRC:
            return te.placeholder([num_rows] + feat_shp, feat_type, name)
        elif target == TargetCode.EDGE:
            return te.placeholder([nnz] + feat_shp, feat_type, name)
        elif target == TargetCode.DST:
            return te.placeholder([num_cols] + feat_shp, feat_type, name)
        else:
            raise DGLError('Unknown target')

    out_feat_shp = [te.var('d_o{}'.format(i), indice_type) for i in range(ndim)]
    lhs_feat_shp = [te.var('d_l{}'.format(i), indice_type) for i in range(ndim)]
    rhs_feat_shp = [te.var('d_r{}'.format(i), indice_type) for i in range(ndim)]
    lhs = create_placeholder(lhs_target, lhs_feat_shp, 'lhs')
    rhs = create_placeholder(rhs_target, rhs_feat_shp, 'rhs')
    # idx wrapper for the corresponding target
    target_getter = {
        TargetCode.SRC: lambda eid: row[eid],
        TargetCode.EDGE: lambda eid: eid,
        TargetCode.DST: lambda eid: col[eid]
    }
    # compute
    out = _sddmm_compute([nnz] + out_feat_shp,
                         binary_op, lhs, rhs,
                         target_getter[lhs_target], target_getter[rhs_target])
    # schedule
    sched = te.create_schedule(out.op)
    if target == 'cuda':
        _sddmm_cuda_schedule[schedule_type](sched, out)
    elif target == 'llvm':
        raise NotImplementedError('CPU schedule not implemented yet.')
    # prepare input
    f_input = []
    f_input.append(row)
    f_input.append(col)
    f_name = '_'.join(str(x) for x in [
        'sddmm', binary_op, ndim,
        indice_type, feat_type,
        lhs_target, rhs_target, schedule_type, target])
    f_input += [lhs, rhs, out]
    # bind auto-broadcast buffers
    lhs_buffer = decl_buffer(lhs.shape, lhs.dtype, name='lhs_buf',
                             buffer_type='auto_broadcast')
    rhs_buffer = decl_buffer(rhs.shape, rhs.dtype, name='rhs_buf',
                             buffer_type='auto_broadcast')
    binds = {lhs: lhs_buffer, rhs: rhs_buffer}
    return tvm.lower(sched, f_input, name=f_name, binds=binds)
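A matching hypothetical sketch for the SDDMM side (same caveats as the gspmm sketch above; TargetCode comes from the branch's .utils module):

import tvm
from gsddmm import gsddmm     # actual import path depends on the branch layout
from utils import TargetCode  # ditto

# attention-style edge score: per-edge dot(src_feat, dst_feat)
ir_mod = gsddmm('dot', ndim=1, indice_type='int32', feat_type='float32',
                lhs_target=TargetCode.SRC, rhs_target=TargetCode.DST,
                schedule_type='tree')
fn = tvm.build(ir_mod, target='cuda')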
Moreover, src/array/kernel.cc also registers new FeatGraph-based operators, although they do not appear to be wired up to anything yet:
// fg-kernel branch, src/array/kernel.cc, line 251
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    NDArray lhs = args[1];
    NDArray rhs = args[2];
    NDArray out = args[3];
    CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
    CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
    CHECK_EQ(graph->NumEdgeTypes(), 1);
    // auto pair = graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
    // const dgl_type_t src_vtype = pair.first;
    // const dgl_type_t dst_vtype = pair.second;
    // CheckShape(
    //     {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
    //     {lhs_target, rhs_target, 1},
    //     {lhs, rhs, out},
    //     {"U_data", "E_data", "V_data"});
    COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
    dgl::featgraph::SDDMM(coo.row.ToDLPack(), coo.col.ToDLPack(),
                          lhs.ToDLPack(), rhs.ToDLPack(), out.ToDLPack());
  });

DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SPMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    NDArray ufeat = args[1];
    NDArray efeat = args[2];
    NDArray out = args[3];
    NDArray argu = args[4];
    NDArray arge = args[5];
    CheckCtx(graph->Context(), {ufeat, efeat, out}, {"ufeat", "efeat", "out"});
    CheckContiguous({ufeat, efeat, out}, {"ufeat", "efeat", "out"});
    CHECK_EQ(graph->NumEdgeTypes(), 1);
    CSRMatrix csc = graph.sptr()->GetCSCMatrix(0);
    dgl::featgraph::SPMM(csc.indptr.ToDLPack(), csc.indices.ToDLPack(),
                         ufeat.ToDLPack(), efeat.ToDLPack(),
                         out.ToDLPack(), argu.ToDLPack(), arge.ToDLPack());
  });
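Both entry points hand everything to dgl::featgraph::SDDMM/SPMM as DLPack tensors, which is what lets the TVM-compiled kernels consume DGL's arrays directly. On the Python side one would presumably reach them through DGL's FFI; a hypothetical sketch (the branch may wire this up differently):

# Hypothetical: look up the registered entry point via DGL's FFI.
from dgl._ffi.function import get_global_func

fg_spmm = get_global_func('sparse._CAPI_FG_SPMM')
# fg_spmm(gidx, ufeat, efeat, out, argu, arge) would then dispatch into
# dgl::featgraph::SPMM through the lambda registered above.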