TVM源码中涉及到表达式遍历的地方,一般是使用VisitExpr接口进行。这个接口涉及TVM的visitor模式,具体分析可以参考TVM之设计模式解读(一)--visitor模式
1. 基类tvm::relay::ExprFunctor
使用visitor遍历的起点是调用VisitExpr接口。我们看下基类tvm::relay::ExprFunctor中这个方法的代码:
template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
...
using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public:
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
...
}
VisitExpr中调用InitVTable,这个代码展开后:
template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
...
public:
...
private:
static FType InitVTable() {
FType vtable;
vtable.template set_dispatch<ConstantNode>([](const ObjectRef& n, TSelf* self, Args... args) {
return self->VisitExpr_(static_cast<const ConstantNode*>(n.get()), std::forward<Args>(args)...); });;
vtable.template set_dispatch<TupleNode>([](const ObjectRef& n, TSelf* self, Args... args) {
return self->VisitExpr_(static_cast<const TupleNode*>(n.get()), std::forward<Args>(args)...); });;
vtable.template set_dispatch<VarNode>([](const ObjectRef& n, TSel

本文解析了TVM源码中如何使用ExprFunctor和ExprVisitor进行表达式遍历,介绍了visitor模式的应用,以及派生类如ExprMutator的特性和在Codegen内存申请中的实际应用。
最低0.47元/天 解锁文章
1481

被折叠的 条评论
为什么被折叠?



