Hi All,
Getting back to this idea: further conversations with @tqchen and @jroesch have uncovered a few edge cases where purely non-recursive passes would be much more difficult to implement.
To match some of the ideas about Block Normal Form, we’d like to propose iteration infrastructure that expands the call graph in dataflow regions, but still allows the user to recurse for control flow/scoping/let binding/etc. We’re calling this mixed-mode traversal, and there is a POC available here that mimics the current ExprVisitor/ExprMutator API. Please take a look at the comments in that PR for an explanation of Call Expansion and Mixed-mode traversal.
This POC provides a relatively simple API for converting current passes to support non-recursive behavior, so we think it’s valuable to implement in the short term.
Longer term, we’d like to propose the following API.
For most passes, traversal is done non-recursively by the PostOrderRewrite function, and the rewriting logic is handled in a non-traversing ExprRewriter:
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
*/
Expr PostOrderRewrite(const Expr& input, ExprRewriter* rewriter);
/*! \brief A non-iterating Expression Rewriter
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewrite on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any altered inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
class ExprRewriter : private ExprFunctor<Expr(const Expr&, const Expr&)> {
public:
/*! \brief Rewrite a node given the original form and the form with modified inputs
*
* Uses ExprFunctor for vtable access.
*
* Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*/
Expr Rewrite(const Expr& pre, const Expr& post) {
return this->VisitExpr(pre, post);
}
virtual Expr Rewrite_(const TupleNode* pre, const TupleNode* post);
virtual Expr Rewrite_(const CallNode* pre, const CallNode* post);
virtual Expr Rewrite_(const TupleGetItemNode* pre, const TupleGetItemNode* post);
...
...
private:
Expr VisitExpr(const Expr& pre, const Expr& post) final {
return ExprFunctor::VisitExpr(pre, post);
}
Expr VisitExpr_(const TupleNode* pre, const TupleNode* post) final {
return Rewrite_(pre, post);
}
Expr VisitExpr_(const CallNode* pre, const CallNode* post) final {
return Rewrite_(pre, post);
}
Expr VisitExpr_(const TupleGetItemNode* pre,
const TupleGetItemNode* post) final {
return Rewrite_(pre, post);
}
...
...
};
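To make the `pre`/`post` contract concrete, here is a minimal, self-contained sketch of the idea on a toy IR. It does not use TVM: `Node`, `Make`, `ToyRewriter`, `ToyPostOrderRewrite`, `ElimDoubleNeg`, and `NegChain` are all hypothetical names invented for illustration, and the explicit-stack loop is only one plausible way to get a post-order traversal without recursion, not the POC's implementation.

```cpp
#include <cassert>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-in for a Relay expression (hypothetical): an op name plus inputs.
struct Node;
using Expr = std::shared_ptr<Node>;
struct Node {
  std::string op;
  std::vector<Expr> args;
};

Expr Make(std::string op, std::vector<Expr> args = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(args)});
}

// Hypothetical rewriter interface: gets the original node (`pre`) and the same
// node rebuilt with already-rewritten inputs (`post`). It must not recurse.
struct ToyRewriter {
  virtual ~ToyRewriter() = default;
  virtual Expr Rewrite(const Expr& pre, const Expr& post) = 0;
};

// Post-order rewrite driven by an explicit stack instead of recursion.
Expr ToyPostOrderRewrite(const Expr& root, ToyRewriter* rewriter) {
  assert(rewriter != nullptr);
  std::unordered_map<const Node*, Expr> memo;  // original node -> rewritten result
  std::stack<std::pair<Expr, bool>> todo;      // (node, inputs already expanded?)
  todo.push({root, false});
  while (!todo.empty()) {
    std::pair<Expr, bool> item = todo.top();
    todo.pop();
    const Expr& node = item.first;
    if (memo.count(node.get())) continue;  // shared node already rewritten
    if (!item.second) {
      // First visit: revisit this node after all of its inputs are done.
      todo.push({node, true});
      for (const Expr& arg : node->args) {
        if (!memo.count(arg.get())) todo.push({arg, false});
      }
    } else {
      // Second visit: every input is in the memo, so build `post`.
      std::vector<Expr> new_args;
      bool changed = false;
      for (const Expr& arg : node->args) {
        const Expr& rewritten = memo.at(arg.get());
        changed = changed || (rewritten != arg);
        new_args.push_back(rewritten);
      }
      Expr post = changed ? Make(node->op, std::move(new_args)) : node;
      memo[node.get()] = rewriter->Rewrite(node, post);
    }
  }
  return memo.at(root.get());
}

// Example pass: neg(neg(x)) -> x, written only in terms of `post`.
struct ElimDoubleNeg : ToyRewriter {
  Expr Rewrite(const Expr& pre, const Expr& post) override {
    (void)pre;  // this pass does not need the original node
    if (post->op == "neg" && post->args[0]->op == "neg") {
      return post->args[0]->args[0];
    }
    return post;
  }
};

// Builds neg(neg(...neg(x)...)) with `depth` applications of neg.
Expr NegChain(int depth) {
  Expr e = Make("x");
  for (int i = 0; i < depth; ++i) e = Make("neg", {e});
  return e;
}
```

Calling `ToyPostOrderRewrite(NegChain(n), &pass)` with an `ElimDoubleNeg pass` never nests calls to `Rewrite`, regardless of `n`; the pass itself only inspects `post` and returns a replacement.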
For passes that need to optimize across scopes or perform other more complex behavior, we will provide the following API. It uses similar naming for user Rewrite_ functions and lets users take more direct control over recursion while still processing dataflow regions non-recursively:
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* Scope Mutator provides the same mixed-mode traversal as DataflowMutator, but provides the
* Rewrite_ API of ExprRewriter for a cleaner split between recursive and non-recursive behavior.
*/
class ScopeMutator : public ::tvm::relay::ExprMutator {
public:
Expr VisitExpr(const Expr& expr) final;
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); }
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); }
/*!
* Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*/
virtual Expr Rewrite_(const TupleNode* pre, const TupleNode* post);
virtual Expr Rewrite_(const CallNode* pre, const CallNode* post);
virtual Expr Rewrite_(const TupleGetItemNode* pre, const TupleGetItemNode* post);
protected:
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
template <typename T>
Expr Rewrite(const T* op) {
// Hold the rewritten Expr in a local so the raw pointer passed to Rewrite_ stays valid.
Expr post = ExprMutator::VisitExpr_(op);
return Rewrite_(op, post.as<T>());
}
virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
};
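As a rough illustration of what mixed-mode traversal means here, the sketch below walks a toy IR in which call chains are expanded with an explicit stack while let bindings are handled by ordinary recursion. It is not TVM code: `Kind`, `Node`, `Make`, `CallChain`, and `ToyMixedModeVisitor` are hypothetical, and only the names `VisitLeaf` and `CheckVisited` are borrowed from the declaration above. Their bodies here are assumptions about what such hooks might do.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <stack>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy IR (hypothetical): calls are dataflow, lets introduce scope.
enum class Kind { kVar, kCall, kLet };

struct Node;
using Expr = std::shared_ptr<Node>;
struct Node {
  Kind kind;
  std::vector<Expr> fields;  // kCall: arguments; kLet: {value, body}; kVar: empty
};

Expr Make(Kind kind, std::vector<Expr> fields = {}) {
  return std::make_shared<Node>(Node{kind, std::move(fields)});
}

// Wraps `base` in `n` nested single-argument calls.
Expr CallChain(Expr base, int n) {
  for (int i = 0; i < n; ++i) base = Make(Kind::kCall, {base});
  return base;
}

class ToyMixedModeVisitor {
 public:
  int visited = 0;    // number of distinct nodes seen
  int max_depth = 0;  // deepest nesting of VisitExpr calls

  void VisitExpr(const Expr& root) {
    ++depth_;
    max_depth = std::max(max_depth, depth_);
    std::stack<std::pair<Expr, bool>> todo;  // (node, arguments already expanded?)
    todo.push({root, false});
    while (!todo.empty()) {
      std::pair<Expr, bool> item = todo.top();
      todo.pop();
      const Expr& node = item.first;
      if (CheckVisited(node)) continue;
      if (node->kind == Kind::kCall && !item.second) {
        // Dataflow: expand arguments on the explicit stack, no recursion.
        todo.push({node, true});
        for (const Expr& arg : node->fields) todo.push({arg, false});
      } else {
        VisitLeaf(node);
      }
    }
    --depth_;
  }

 private:
  bool CheckVisited(const Expr& node) const { return done_.count(node.get()) != 0; }

  void VisitLeaf(const Expr& node) {
    done_.insert(node.get());
    ++visited;
    if (node->kind == Kind::kLet) {
      // Scoping: fall back to plain recursion for the bound value and the body.
      assert(node->fields.size() == 2);
      VisitExpr(node->fields[0]);
      VisitExpr(node->fields[1]);
    }
  }

  int depth_ = 0;
  std::unordered_set<const Node*> done_;
};
```

In this toy, a chain of calls of any length is handled within a single `VisitExpr` frame, and recursion depth grows only with the nesting of let bindings, which is the trade-off the proposal describes.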
Thank you for your thoughts!