Performing Relay Passes Non-Recursively

Recently, a number of more complex models (LSTMs in my case) have run into segmentation faults in Relay optimization passes. This is primarily because the current passes are written recursively, and very deep traversals of the graph can overflow the stack.
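To make the failure mode concrete, here is a minimal sketch on a hypothetical arena-based chain (not Relay's Expr type). A recursive visit uses one native stack frame per node, while an iterative visit uses a constant amount of native stack. The names DepthRecursive, DepthIterative, and MakeChain are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a long chain of dataflow nodes (e.g. an
// unrolled LSTM), stored in an arena: input[i] is the index of node i's
// single operand, or -1 for a leaf. Not Relay's data structure.
using Chain = std::vector<long>;

// Recursive visit: one native stack frame per node, so stack usage grows
// linearly with graph depth.
std::size_t DepthRecursive(const Chain& input, long n) {
  if (n < 0) return 0;
  return 1 + DepthRecursive(input, input[n]);
}

// Iterative visit: constant native stack usage regardless of depth.
std::size_t DepthIterative(const Chain& input, long n) {
  std::size_t depth = 0;
  for (; n >= 0; n = input[n]) ++depth;
  return depth;
}

// Builds a chain of `len` (> 0) nodes where node i feeds on node i + 1.
Chain MakeChain(long len) {
  Chain input(static_cast<std::size_t>(len));
  for (long i = 0; i < len; ++i) input[i] = (i + 1 < len) ? i + 1 : -1;
  return input;
}
```

Whether the recursive version actually crashes on a given depth depends on the thread's stack size and on whether the compiler turns the recursion into a loop, so the sketch only contrasts the two shapes.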

This issue has been discussed a number of times, notably here:

and in the v0.7 roadmap

In order to support this use case, I’ve done a POC for a non-recursive Graph Visitor and Rewriter in this PR: https://github.com/apache/incubator-tvm/pull/4886

There are two main classes: PostOrderGraphVisitor, which inherits from ExprFunctor and provides a non-recursive graph traversal with an optional visitation function (similar to the current PostOrderVisit function), and ExprRewriter, which inherits from ExprMutator but overrides its recursive behavior, relying instead on PostOrderGraphVisitor to provide the iteration order.
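A minimal sketch of the idea behind a non-recursive post-order visitor, assuming a toy DAG of integer node ids rather than Relay's ExprFunctor machinery: an explicit stack of (node, expanded) pairs replaces the native call stack, and each node is visited exactly once, after all of its inputs. PostOrderVisitIterative is a hypothetical name, not the API in the PR.

```cpp
#include <cassert>
#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy DAG (hypothetical): g[i] lists the operands of node i.
using Graph = std::vector<std::vector<int>>;

// Post-DFS visit without native recursion. A node is first pushed
// "unexpanded"; when popped it is re-pushed as "expanded" underneath its
// inputs, so it is visited only after all of them.
void PostOrderVisitIterative(const Graph& g, int root,
                             const std::function<void(int)>& fvisit) {
  std::unordered_set<int> visited;
  std::vector<std::pair<int, bool>> stack{{root, false}};
  while (!stack.empty()) {
    auto [node, expanded] = stack.back();
    stack.pop_back();
    if (visited.count(node)) continue;  // shared node already handled
    if (expanded) {
      visited.insert(node);
      fvisit(node);
    } else {
      stack.push_back({node, true});  // revisit after the inputs
      // Push inputs in reverse so they are visited left to right.
      for (auto it = g[node].rbegin(); it != g[node].rend(); ++it) {
        if (!visited.count(*it)) stack.push_back({*it, false});
      }
    }
  }
}
```

For a diamond 0 -> {1, 2}, 1 -> {3}, 2 -> {3}, the visit order is 3, 1, 2, 0, and the shared node 3 is visited once, matching what PostOrderVisit does recursively.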

As a proof of concept, I changed the ForwardRewriter passes to use the new classes. The main advantage of this design is that it maintains much of the same API as the current ExprMutator/ExprVisitor, so updating passes to use the new infrastructure is relatively simple.

There is some argument to be made that a more direct break between the current behavior and the non-recursive behavior might be in order, so I’d like to get input from the broader community.

Thanks!

cc @tqchen @jroesch


This should work for TVM IR as well, once we unify IR infra, right?

After introducing SeqStmt in the TVM low-level IR, we likely won’t have the stack overflow problem anymore, because a recursive visit only happens when we introduce a new scope, not for a long sequence. The maximum stack depth will be the maximum depth of the scope, plus the depth of the local expression.


Thanks @mbrookhart for a great proposal. It would be great to think about how we can get a concise and clean interface for the non-recursive visitor. Here is a possible strawman proposal:

    // Visit elements of input recursively, applying rewriter to each of the nodes.
    Expr PostOrderRewrite(Expr input,
                          std::function<Expr(Expr post_order, const Expr& orig_pre_order)> rewriter);

    Expr PostOrderRewrite(Expr input,
                          std::function<Expr(Expr post_order)> rewriter);

    Expr PreOrderRewrite(Expr input,
                         std::function<Expr(Expr pre_order)> rewriter);

    // Example
    // Implement TempExprRealize: recursively realize all temp expressions in input.
    Expr RealizeTempExpr(Expr input) {
      return PreOrderRewrite(input, [](Expr pre_order) -> Expr {
        if (auto* node = pre_order.as<TempExprNode>()) {
          // The rewriter will call visit on the realized nodes here.
          return node->Realize();
        } else {
          return pre_order;
        }
      });
    }
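The comment in RealizeTempExpr implies that PreOrderRewrite applies the rewriter to a node first and then continues into the operands of whatever the rewriter returned. A minimal sketch of those semantics without native recursion, assuming a toy string-tagged IR in which a hypothetical "temp" node stands in for TempExprNode and "realizing" it wraps its operand in a "cast" node; Make and PreOrderRewrite here are illustrative, not TVM code.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy immutable IR (hypothetical, not Relay).
struct Node;
using Expr = std::shared_ptr<const Node>;
struct Node {
  std::string op;
  std::vector<Expr> args;
};
Expr Make(std::string op, std::vector<Expr> args = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(args)});
}

// Applies `rewriter` to a node before its operands, then traverses the
// operands of the *result*, then rebuilds the result with rewritten
// operands. Assumes the rewriter never returns a graph containing the
// node itself (that would loop forever, recursively or not).
Expr PreOrderRewrite(const Expr& input,
                     const std::function<Expr(const Expr&)>& rewriter) {
  std::unordered_map<const Node*, Expr> applied;  // pre -> rewriter(pre)
  std::unordered_map<const Node*, Expr> done;     // pre -> final result
  std::vector<std::pair<Expr, bool>> stack{{input, false}};
  while (!stack.empty()) {
    auto [pre, expanded] = stack.back();
    stack.pop_back();
    if (done.count(pre.get())) continue;
    if (!expanded) {
      Expr r = rewriter(pre);
      applied[pre.get()] = r;
      stack.push_back({pre, true});  // rebuild after the operands
      for (const Expr& a : r->args) stack.push_back({a, false});
      continue;
    }
    const Expr& r = applied.at(pre.get());
    std::vector<Expr> args;
    for (const Expr& a : r->args) args.push_back(done.at(a.get()));
    done[pre.get()] = args == r->args ? r : Make(r->op, args);
  }
  return done.at(input.get());
}
```

With a rewriter that maps temp(e) to cast(e), the input temp(add(temp(x), x)) becomes cast(add(cast(x), x)): the nested temp is realized because traversal continues into the rewritten node.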

Another thing to note is that this approach only solves the problem of rewriting dataflow regions; we still need to come up with a general recursive visitor that handles control scopes and invokes the dataflow rewriter in each scope.

Does this proposal allow us (users of TVM) to easily get a graph representation of a Relay program?

I have been struggling to traverse Relay programs and get a graph representation of them.

Hi All,

Getting back to this idea, some more conversations between myself, @tqchen, and @jroesch have uncovered a few edge cases where purely non-recursive passes will be much more difficult to implement.

To match some of the ideas about Block Normal Form, we’d like to propose iteration infrastructure that expands the call graph in dataflow regions but still allows the user to recurse for control flow, scoping, let bindings, etc. We’re calling this mixed-mode traversal, and there is a POC available here that mimics the current ExprVisitor/ExprMutator API. Please take a look at the comments in that PR for an explanation of Call Expansion and mixed-mode traversal.
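A rough sketch of mixed-mode traversal under simplifying assumptions: a hypothetical arena IR with only leaf, call, and if nodes, counting call nodes. This is not the PR's implementation. Dataflow within a scope is unrolled with an explicit stack, and only an if (whose branches open new scopes) triggers a native recursive call, so native stack depth tracks scope nesting rather than the length of a dataflow chain.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical arena IR (indices instead of pointers), not Relay's.
enum class Kind { kLeaf, kCall, kIf };
struct Node {
  Kind kind;
  std::vector<int> args;  // kCall: dataflow inputs; kIf: {cond, then, else}
};

// Counts call nodes with mixed-mode traversal and records the maximum
// native recursion depth reached (one level per nested scope).
struct MixedModeCounter {
  const std::vector<Node>& nodes;
  std::size_t calls = 0;
  int depth = 0;
  int max_depth = 0;

  void VisitScope(int root) {
    ++depth;
    if (depth > max_depth) max_depth = depth;
    std::vector<bool> seen(nodes.size(), false);
    std::vector<int> stack{root};  // explicit stack for this scope's dataflow
    while (!stack.empty()) {
      int n = stack.back();
      stack.pop_back();
      if (seen[n]) continue;
      seen[n] = true;
      const Node& node = nodes[n];
      if (node.kind == Kind::kIf) {
        stack.push_back(node.args[0]);  // condition: dataflow of this scope
        VisitScope(node.args[1]);       // branches: new scopes, so recurse
        VisitScope(node.args[2]);
      } else {
        if (node.kind == Kind::kCall) ++calls;
        for (int a : node.args) stack.push_back(a);
      }
    }
    --depth;
  }
};
```

A chain of 100,000 calls nested two ifs deep is counted with a native recursion depth of only 3; visit order does not matter for counting, so the sketch does not enforce post-order.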

This POC provides a relatively simple API for converting current passes to support non-recursive behavior, therefore we think it’s valuable to implement in the short term.

Longer term, we’d like to propose this API:

For most passes, we suggest using this API, where traversal is done non-recursively by the PostOrderRewrite function; rewriting logic is handled in a non-traversing ExprRewriter:

    /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
     *
     *  PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
     * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
     * PostOrderRewrite provides the original node and the node with altered inputs for use by the
     * ExprRewriter.
     */
    Expr PostOrderRewrite(const Expr& input, ExprRewriter* rewriter);
    
    /*! \brief A non-iterating Expression Rewriter
     *
     *  ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
     *  The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
     * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
     * node, called `pre`, and a node recreated with any altered inputs, called `post`, to the
     * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
     * graph rewriting.
     */
    class ExprRewriter : private ExprFunctor<Expr(const Expr&, const Expr&)> {
     public:
      /*! \brief Rewrite a node given the original form and the form with modified inputs
       *
       *  Uses ExprFunctor for vtable access.
       *
       *  Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
       * able to rewrite the op only with data about the original node `pre` and the same node with
       * modified inputs `post` and should not recurse.
       */
      Expr Rewrite(const Expr& pre, const Expr& post) {
        return this->VisitExpr(pre, post);
      }
      virtual Expr Rewrite_(const TupleNode* pre, const TupleNode* post);
      virtual Expr Rewrite_(const CallNode* pre, const CallNode* post);
      virtual Expr Rewrite_(const TupleGetItemNode* pre, const TupleGetItemNode* post);
      ...
     private:
      Expr VisitExpr(const Expr& pre, const Expr& post) final {
        return ExprFunctor::VisitExpr(pre, post);
      }
      // ExprFunctor forwards the extra argument as `const Expr&`; `post` has the
      // same node type as `pre`, so it is downcast before calling Rewrite_.
      Expr VisitExpr_(const TupleNode* pre, const Expr& post) final {
        return Rewrite_(pre, post.as<TupleNode>());
      }
      Expr VisitExpr_(const CallNode* pre, const Expr& post) final {
        return Rewrite_(pre, post.as<CallNode>());
      }
      Expr VisitExpr_(const TupleGetItemNode* pre, const Expr& post) final {
        return Rewrite_(pre, post.as<TupleGetItemNode>());
      }
      ...
    };
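A compact, self-contained analogue of the proposed ExprRewriter/PostOrderRewrite split, assuming a toy IR with only constants and adds. Here dynamic_cast stands in for ExprFunctor's vtable and enable_shared_from_this stands in for GetRef; ConstantFolder is a hypothetical example pass. It also shows that a Rewrite_ override may return a different node type (a ConstNode for an AddNode), while `post` always has the same type as `pre`.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy IR (hypothetical, not Relay): integer constants and binary adds.
struct Node : std::enable_shared_from_this<Node> {
  virtual ~Node() = default;
};
using Expr = std::shared_ptr<const Node>;
struct ConstNode : Node {
  int value;
  explicit ConstNode(int v) : value(v) {}
};
struct AddNode : Node {
  Expr lhs, rhs;
  AddNode(Expr l, Expr r) : lhs(std::move(l)), rhs(std::move(r)) {}
};

class ExprRewriter {
 public:
  virtual ~ExprRewriter() = default;
  // Dispatch on node type; the driver guarantees post has pre's type.
  Expr Rewrite(const Expr& pre, const Expr& post) {
    if (auto* p = dynamic_cast<const AddNode*>(pre.get()))
      return Rewrite_(p, static_cast<const AddNode*>(post.get()));
    return Rewrite_(static_cast<const ConstNode*>(pre.get()),
                    static_cast<const ConstNode*>(post.get()));
  }

 protected:
  // Defaults keep the node with rewritten inputs; overrides must not recurse.
  virtual Expr Rewrite_(const ConstNode* /*pre*/, const ConstNode* post) {
    return post->shared_from_this();
  }
  virtual Expr Rewrite_(const AddNode* /*pre*/, const AddNode* post) {
    return post->shared_from_this();
  }
};

// Non-recursive post-DFS driver: rewrites inputs first, rebuilds `post`
// only if an input changed, then hands (pre, post) to the rewriter.
Expr PostOrderRewrite(const Expr& input, ExprRewriter* rewriter) {
  std::unordered_map<const Node*, Expr> memo;  // pre node -> rewritten node
  std::vector<std::pair<Expr, bool>> stack{{input, false}};
  while (!stack.empty()) {
    auto [pre, expanded] = stack.back();
    stack.pop_back();
    if (memo.count(pre.get())) continue;
    auto* add = dynamic_cast<const AddNode*>(pre.get());
    if (!expanded) {
      stack.push_back({pre, true});  // revisit once inputs are rewritten
      if (add) {
        stack.push_back({add->rhs, false});
        stack.push_back({add->lhs, false});
      }
      continue;
    }
    Expr post = pre;
    if (add) {
      const Expr& l = memo.at(add->lhs.get());
      const Expr& r = memo.at(add->rhs.get());
      if (l != add->lhs || r != add->rhs) post = std::make_shared<AddNode>(l, r);
    }
    memo[pre.get()] = rewriter->Rewrite(pre, post);
  }
  return memo.at(input.get());
}

// Example pass: fold add(const, const) into a single constant.
class ConstantFolder : public ExprRewriter {
 protected:
  using ExprRewriter::Rewrite_;
  Expr Rewrite_(const AddNode* pre, const AddNode* post) override {
    auto* l = dynamic_cast<const ConstNode*>(post->lhs.get());
    auto* r = dynamic_cast<const ConstNode*>(post->rhs.get());
    if (l && r) return std::make_shared<ConstNode>(l->value + r->value);
    return ExprRewriter::Rewrite_(pre, post);
  }
};
```

Folding add(add(1, 2), 4) yields a single constant 7: the inner add is folded first, so the outer add's `post` already sees two constants.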

For passes that need to optimize across scopes or perform other more complex behavior, we will provide the following API, with similar naming for the user’s Rewrite_ functions, where users can take more direct control over recursion but still process dataflow regions non-recursively.

    /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
     *
     *  ScopeMutator provides the same mixed-mode traversal as DataflowMutator, but provides the
     * Rewrite_ API of ExprRewriter for a cleaner split between recursive and non-recursive behavior.
     */
    class ScopeMutator : public ::tvm::relay::ExprMutator {
     public:
      Expr VisitExpr(const Expr& expr) final;
      Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
      Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
      Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
      /*! 
       *  Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
       * able to rewrite the op only with data about the original node `pre` and the same node with
       * modified inputs `post` and should not recurse.
       */
      virtual Expr Rewrite_(const TupleNode* pre, const TupleNode* post);
      virtual Expr Rewrite_(const CallNode* pre, const CallNode* post);
      virtual Expr Rewrite_(const TupleGetItemNode* pre, const TupleGetItemNode* post);
    
     protected:
      /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
       * changed inputs.
       */
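      template <typename T>
      Expr Rewrite(const T* op) {
        // Hold the rebuilt node in an Expr so it stays alive while Rewrite_ reads it;
        // calling .as<T>() on the temporary would leave `post` dangling.
        Expr post = ExprMutator::VisitExpr_(op);
        return Rewrite_(op, post.as<T>());
      }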
    
      virtual void VisitLeaf(const Expr& expr);
      virtual bool CheckVisited(const Expr& expr);
    };

Thank you for your thoughts!

cc @yzhliu @MarisaKirisame @junrushao @haichen @janimesh @mbaret @jonso @zhiics It would be great to get feedback, given that this change is going to affect all Relay passes.

This API looks pretty reasonable to me. Is there a reason pre and post must be of the same type, though? I would assume post should just be a const Expr&.

I don’t see any major difficulties in adapting the passes I’ve worked with to use this. I think pre and post have to be the same type because post will be the node after its inputs have been rewritten. I presume you can still rewrite the type of the node itself when you return an Expr from Rewrite_.

I don’t see any major issues either. This is a welcome addition. Thanks for introducing this.

The API design looks good to me. I wonder whether this non-recursive pass infrastructure can support a customized template, as there are many passes that have different return types and probably different input types as well.

I like the proposal. Overall looks good to me.

Yep! This is the envisioned use case: return any Expr you like; the provided post will always be the same node type with rewritten inputs.

@haichen Could you point me at a pass that uses different return types? I’m assuming this is something that skips using ExprMutator and uses ExprFunctor directly? I’d like to think through ways to support this.

Yes, that’s right: for example, the AD pass, partial eval, and ToANF. It would be more flexible if we could support a customized functor, but it’s okay to do that in a follow-up PR.

Writing a trampoline by hand is probably the way to go for those passes.
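One way to read “writing a trampoline by hand” for passes with custom return types: replace each nested call with a closure pushed onto an explicit task stack, pass results through a value stack, and let a driver loop run the closures one at a time. A minimal sketch, assuming a hypothetical arena tree and a custom Info result instead of Expr; Analyze and Info are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical arena tree: t[i] lists the children of node i.
using Tree = std::vector<std::vector<int>>;

// Custom (non-Expr) result, as in passes built directly on ExprFunctor.
struct Info {
  std::size_t nodes = 0;
  std::size_t height = 0;
};

// Trampolined version of a recursive `Info Visit(node)`. Each step pushes
// closures instead of calling itself; the driver loop "bounces" between
// closures, so native stack depth stays constant.
Info Analyze(const Tree& t, int root) {
  std::vector<std::function<void()>> tasks;
  std::vector<Info> values;  // results of finished subtrees
  std::function<void(int)> visit = [&](int n) {
    // Continuation: runs after every child's Info is on `values`.
    tasks.push_back([&, n] {
      Info out{1, 0};
      for (std::size_t k = 0; k < t[n].size(); ++k) {
        Info c = values.back();
        values.pop_back();
        out.nodes += c.nodes;
        if (c.height + 1 > out.height) out.height = c.height + 1;
      }
      values.push_back(out);
    });
    for (int c : t[n]) tasks.push_back([&, c] { visit(c); });
  };
  tasks.push_back([&] { visit(root); });
  while (!tasks.empty()) {
    auto task = std::move(tasks.back());  // move out before the task pushes more
    tasks.pop_back();
    task();
  }
  return values.back();
}
```

The same shape works for any result type, which is the appeal for passes like partial evaluation or ToANF; the cost is that the control flow of the original recursive function has to be split into explicit continuations by hand.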

Hmm… after a bit of thinking, doesn’t this force a visit to all children, even those that don’t necessarily need it? Would making the right-hand side a lazy value make more sense?
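A sketch of the lazy right-hand side idea, under the assumption that post is handed to the rewriter as a memoizing thunk; Lazy, LazyRewriter, and RewriteLazy are hypothetical names on a toy IR. Children are only rewritten if the rewriter forces post. The driver is deliberately recursive to keep the example short; combining laziness with non-recursive traversal is the harder part and is not addressed here.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy immutable IR (hypothetical, not Relay).
struct Node;
using Expr = std::shared_ptr<const Node>;
struct Node {
  std::string op;
  std::vector<Expr> args;
};
Expr Make(std::string op, std::vector<Expr> args = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(args)});
}

// Memoizing lazy value: get() runs the thunk at most once.
template <typename T>
class Lazy {
 public:
  explicit Lazy(std::function<T()> thunk) : thunk_(std::move(thunk)) {}
  const T& get() {
    if (!value_) value_ = thunk_();
    return *value_;
  }

 private:
  std::function<T()> thunk_;
  std::optional<T> value_;
};

// Rewriter receives `post` lazily and decides whether to force it.
using LazyRewriter = std::function<Expr(const Expr& pre, Lazy<Expr>& post)>;

// `visits` counts nodes whose rewrite was actually started.
Expr RewriteLazy(const Expr& pre, const LazyRewriter& f, int& visits) {
  ++visits;
  Lazy<Expr> post([&]() -> Expr {
    std::vector<Expr> args;
    for (const Expr& a : pre->args) args.push_back(RewriteLazy(a, f, visits));
    return Make(pre->op, args);
  });
  return f(pre, post);
}
```

A rewriter that returns an "opaque" node unchanged without calling post.get() never visits that node's subtree, which is exactly the work an eager post would have forced.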

Hmm, that’s a good point. The other thing I’m currently working on is the Pattern Language, and I’m noticing similar kinds of inflexibility there. I’ll think about how to make this a little more general.

Okay, I’ve implemented these APIs and converted a couple of passes to use them, please take a look and review. https://github.com/apache/incubator-tvm/pull/4886

@haichen @MarisaKirisame I’m noticing that the passes with these edge cases all seem to be doing more complicated things than the simple post-order DFS targeted here. I’m still looking at them, but I think that should be a second PR?

Thanks for all of the feedback!
