[pre-RFC] TVM Explorer Infrastructure

Hi community:

We would like to propose a pre-RFC to enhance the debugging ability of TVM. This pre-RFC describes the details of how we enhance the current functionality, and also covers the visualizations we can get from TVM Explorer, a web GUI project built on top of this RFC.

Summary

The goal of this RFC is to extend the capability of tracing source information between different IRs for debugging purposes. Three features benefit from this change:

  • Map layer names from the ML frontend IR to Relay IR.
  • Record source expressions to the transformed ones during pass optimizations.
  • Queue snapshots of a schedule after each change made by a primitive.

These changes provide users a clear backtrace of an IR in CLI text format. Furthermore, paired with our ongoing project TVM Explorer, a colorful and convenient GUI improves the user experience even further. We demonstrate the use cases of TVM Explorer with examples in the following sections.

Motivation

We aim to ease the debugging process by enhancing and creating features that carry source information. TVM performs a number of transformations to optimize and deploy an ML frontend IR to a targeted device. However, the modules which record source information between IRs are currently not fully used, which makes it hard for users to trace the source of a transformed IR. Usually an investigation of the source code is needed to understand the details of a transformation.

We provide the following enhancements to reduce users’ effort by recording source information between IRs and the schedules of op implementations:

  1. Frontend span filler: Fill the layer name into Relay IR during frontend conversion.
  2. Pass source information builder: Construct SequentialSpan from Span, and SIBuilder, to handle source information for both Relay IR and TIR.
  3. Schedule/Stage visualization enhancement: Record and propagate snapshots of an op’s schedule as primitives are applied in the regular build flow.

After these modifications, users can obtain the source information at a glance or via a debugger.

Finally, inspired by Compiler Explorer, we built a web GUI, TVM Explorer, for TVM. Based on the infrastructure above, TVM Explorer provides a better user experience when comparing IRs or analyzing schedules (the code base of TVM Explorer is maintained in another git repository and is not included in this RFC).

[Figure: TVM Explorer]


Guide-level explanation

TVM infrastructures

Frontend span filler

Based on ExprMutator, we implement set_span to recursively fill source information into Relay IR during op conversion, so we can obtain Relay IR with spans even in a one-to-many conversion. Take the Pack op from TensorFlow for example: it inserts multiple expand_dims ops during conversion:

# implementation of the Pack op conversion in the TF frontend
def _pack():
    def _impl(inputs, attr, params, mod):
        axis = int(attr["axis"])
        inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
        return _op.concatenate(inputs_reshaped, axis)

    return _impl

# After converting an op from the frontend
ret = self.convert_map[op_code_str](op)
ret = set_span(ret, frontend_layer_name)

'''
The result of a Pack op conversion, before and after set_span
def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, …) /* strided_slice */;
    %2 = squeeze(%1) /* strided_slice */;
}
======>
def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, …) /* strided_slice */;
    %2 = squeeze(%1) /* strided_slice */;
    %3 = expand_dims(%2, axis=0) /* stack */;
    %4 = expand_dims(3, axis=0) /* stack */;
    %5 = expand_dims(3, axis=0) /* stack */;
    %6 = (%3, %4, %5) /* stack */;
    %7 = concatenate(%6) /* stack */;
}
'''

Pass source information builder

To manage span propagation in passes, we extend Span with a new subclass SequentialSpan, and create a new class SIBuilder. First, we construct SequentialSpan, a container class that carries a set of source spans in its member variable for many-to-n (n>=1) conversions, which are common in transformations between passes:

// C++
SequentialSpan new_span{expr_1->span, expr_2->span};
# Python
relay.SequentialSpan([expr_1, expr_2])

Take the IfNode condition in the FoldConstant pass for example. When the condition is a constant, FoldConstant extracts the expression of the triggered path as the result. We create a SequentialSpan to keep both the existing span from the selected branch and the span from the discarded If expression.

Expr VisitExpr_(const IfNode* if_node) final {
  If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node));
  if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) {
    Expr ret;
    if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
      ret = new_if->true_branch;
    } else {
      ret = new_if->false_branch;
    }
    ret->span = SequentialSpan({ret->span, new_if->span});
    return ret;
  }
  return std::move(new_if);
}

On the other hand, SIBuilder aims to ease developers’ workload when filling spans in pass transformations. Based on our experience filling spans in existing passes, we provide two functionalities in SIBuilder. First, RecursivelyFillSpan provides an easy way to automatically fill a source span into conversions which result in multiple expressions. Given a source span, RecursivelyFillSpan applies a DFS traversal from “start_expression” and fills in the source span until it encounters any of the given inputs.

SIBuilder si_builder(source_span);
si_builder.RecursivelyFillSpan(start_expression, {inputs_of_the_first_new_generated_expr});

A use case of RecursivelyFillSpan is SimplifyInference. This pass simplifies certain operators during inference. Take BatchNorm for example: SimplifyInference unpacks the Call of BatchNorm and its TupleGetItem indexed at 0 into several simplified expressions. In this case we can invoke RecursivelyFillSpan to fill the span into the newly generated expressions once and for all.

Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata, Span span) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);

  const auto param = attrs.as<BatchNormAttrs>();
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
  //...

  Expr out = Multiply(data, scale);
  out = Add(out, shift);

  SIBuilder si_builder(span);
  si_builder.RecursivelyFillSpan(/* entry */ out,
                                 /* inputs */ {data, gamma, beta, moving_mean, moving_var});
  return out;
}

Second, SIBuilder provides a constructor to collect a contiguous sequence of source spans. Starting from entry, it puts the span of each Expr into its array member variable and continues the traversal until it hits the inputs. Finally, invoking CreateSpan on the created SIBuilder instance yields the source span.

SIBuilder si_builder(entry_expr, {inputs});
new_span = si_builder.CreateSpan();

This constructor works well in the SimplifyExpr pass. One of its rewrites is SimplifyReshape, whose pattern is an expression followed by two consecutive reshape or contrib_reverse_reshape ops. In this case we can use the SIBuilder constructor above to obtain all source spans of the matched pattern.

class SimplifyReshape : public DFPatternRewrite {
 public:
  SimplifyReshape() {
    x_ = IsWildcard();
    auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    pattern_ = reshape1({reshape2({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    //...
    if (const_shape) {
      auto x = node_map[x_][0];
      auto ret = MakeReshape(x, newshape);

      SIBuilder si_builder(/* entry */ node_map[pattern_][0], /* inputs */ {x});
      ret->span = si_builder.CreateSpan();

      return ret;
    //...
};

Based on the classes above, we have filled spans in all Relay passes used in the build flow.

Schedule/Stage Visualization Enhancement

Tensor Expressions are scheduled with primitives, and a schedule quickly becomes complicated as the number of applied primitives grows. Although TEDD (Tensor Expression Debug Display) already provides a mechanism to visualize different kinds of schedule diagrams (Schedule Tree, IterVar Relationship, and Dataflow), it is still hard to recognize the effect of each applied primitive from the resulting diagrams.

We propose a change that records a snapshot of the schedule after each new primitive is applied, by introducing some modifications to the interface of the Schedule/Stage classes. New APIs will also be added so that schedules created inside the TVM build flow can be inspected.

By doing so, we can leverage TEDD to display a sequence of schedule diagrams. The following is a snippet of the driving code and the corresponding result:

# load TFLite model
tflite_model_buf = open('mobilenet.tflite', "rb").read()
model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
input_shape = {'input': (1, 224, 224, 3)}
mod, params = relay.frontend.from_tflite(model, input_shape)

# invoke build process
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, 'llvm', params=params)

# (new API) get schedule from 17th node in TVM graph
sch = lib.build_module.get_node_schedule(17)
# (new API) get schedule record (snapshots of schedule)
schs = sch.schedule_record

# the second to last schedule
ori_dot = tedd.viz_schedule_tree(schs[-2].normalize(), dot_file_path="ori.dot")
# the last schedule with all the optimization strategies
cmp_dot = tedd.viz_schedule_tree(schs[-1].normalize(), dot_file_path="cmp.dot")

[Figure: Ori.png, schedule tree of the second-to-last schedule]



We can see the effect of the applied primitive “compute_at”, which moves the computation scope of “conv2d” inside the outer loop of the “cast” stage:

[Figure: Cmp.png, schedule tree after compute_at is applied]
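For readers without the rendered diagrams, here is a minimal, self-contained TE sketch of what compute_at does (hypothetical shapes and stage names, not the actual MobileNet schedule above):

import tvm
from tvm import te

n = 16
A = te.placeholder((n, n), name="A")
conv = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="conv2d")
cast = te.compute((n, n), lambda i, j: conv[i, j].astype("float16"), name="cast")

s = te.create_schedule(cast.op)
# Move the computation of "conv2d" inside the outer loop of the "cast" stage.
s[conv].compute_at(s[cast], cast.op.axis[0])
print(tvm.lower(s, [A, cast], simple_mode=True))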


TVM Explorer preview

Inspired by Compiler Explorer, TVM Explorer is our ongoing project: a web GUI to investigate TVM behaviors. Based on the infrastructure above, TVM Explorer achieves the following goals:

  1. Frontend span filling: Link converted Relay IR with the frontend IR and interact between them.

[Figure: Frontend mapping based on Span (Netron)]



[Figure: Frontend mapping based on Span (Pretty-print)]

  2. Source information builder: Find the source expressions between transformations of passes.

[Figure: Pass conversion (Unpack batchnorm)]

  3. Schedule/Stage visualization enhancement: TVM Explorer provides a mechanism to visualize the computation graph generated from GraphExecutor. With the proposed changes, the Schedule data structure is kept inside each graph node, where users can visualize the implementation details:

[Figure: Computation graph.png]


Clicking into the graph nodes shows further comparisons, like highlighting the difference between two schedules after applying a primitive:

[Figure: Schedule comparison.png]


Reference-level explanation

Frontend span filling

This feature was introduced previously in PR-9723, but was reverted because of the unexpected duplicated-expression problem described in PR-10072. We fix that issue and propose a modified version with the following differences:

  1. Fix the problem of duplicated expressions in the PyTorch conversion:
    Previously we did not invoke set_span in every required condition, and did not handle tuple/list types properly during the PyTorch conversion, which resulted in duplicated expressions. After investigation, we now insert set_span in each required place to avoid the duplication.
  2. Remove the suffix in source information:
    During pass transformations the suffix tends to clutter the source information of an expression, so we remove the _PART_N suffix for concision.
  3. Redesign PyTorch source information:
    Since a representative PyTorch source information might not exist, we introduce a method to reconstruct a new one.
  4. Use the environment variable TVM_SPANFILLING to disable/enable span filling:
    If span filling is not required, set the environment variable with export TVM_SPANFILLING=0 to disable the procedure.
def set_span(sym, span):
    """Set up the span of relay expression(s) while converting an op"""

    class SpanFiller(ExprMutator):
        """SpanFiller"""

    return SpanFiller(span).fill(sym) if _should_fill_span() else sym
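The guard _should_fill_span used above could read the environment variable along these lines (a sketch; the actual helper may differ):

import os

def _should_fill_span():
    # Span filling is enabled by default; export TVM_SPANFILLING=0 disables it.
    return os.environ.get("TVM_SPANFILLING", "1") != "0"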

The following shows the details of set_span. The constructor accepts both string and Span formats as its source information. The fill function accepts only types in a whitelist to prevent unexpected symbols. The visit function stops traversing deeper once the flow hits an expression that already has a span. In the dispatched visit functions such as visit_call, SpanFiller reconstructs and returns a new expression with the given span.

class SpanFiller(ExprMutator):
    """SpanFiller"""
    def __init__(self, span):
        ExprMutator.__init__(self)
        if isinstance(span, tvm.relay.Span):
            self._span = span
        elif isinstance(span, str):
            self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0)
        else:
            assert False, f"unsupported span type: {type(span)}"

    def visit(self, expr):
        if hasattr(expr, "span") and expr.span:
            return expr
        return super().visit(expr)

    def visit_call(self, call):
        new_args = [self.visit(arg) for arg in call.args]
        return _expr.Call(call.op, new_args, call.attrs, call.type_args, self._span)
    #...

    def fill(self, sym):
        if isinstance(sym, _expr.TupleWrapper):
            return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
        elif isinstance(sym, _expr.RelayExpr):
            return self.visit(sym)
        elif isinstance(sym, list):
            assert all(
                isinstance(expr, _expr.TupleGetItem) for expr in sym
            ), f"unexpected relay expressions in {sym}"
            return [self.visit(expr) for expr in sym]
        elif isinstance(sym, tuple):
            assert all(
                isinstance(expr, _expr.RelayExpr) for expr in sym
            ), f"unexpected relay expressions in {sym}"
            return tuple(self.visit(expr) for expr in sym)

        assert False, f"unsupported type {type(sym)}"

Pass source information builder

  • SequentialSpan:
    Inheriting from Span, SequentialSpan accepts a sequence of Spans and puts them into its tvm::Array. For many-to-n (n>=1) transformations, SequentialSpan is a good container to carry their sources. When comparing the equality of two SequentialSpans, we simply fall back to comparing the equality of each contained span iteratively.

    UML of SequentialSpan

class SequentialSpanNode : public SpanNode {
 public:
  /*! \brief A list of spans that used to compose a sequential span. */
  tvm::Array<Span> spans;
  static constexpr const char* _type_key = "SequentialSpan";
  bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const;
  TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode);
};
 
class SequentialSpan : public Span {
 public:
  TVM_DLL SequentialSpan(Array<Span> spans);
  TVM_DLL SequentialSpan(std::initializer_list<Span> init);
};
  • SIBuilder:
    SIBuilder provides two functionalities for both Relay and TIR pass transformations. One is recursively filling spans into newly generated expressions that lack one. The other is collecting source spans from a contiguous sequence of expressions. The following UML demonstrates an overview of SIBuilder:

    UML of SIBuilder

class SIBuilder {
 public:
  explicit SIBuilder(const Span& span = Span());
 
  /*!
   * \brief Create SIBuilder via a subgraph,
   *        will construct span based on the exprs falls in the subgraph
   *
   * \param entry Entry expr for subgraph
   * \param inputs End exprs for subgraph
   */
  template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
  explicit SIBuilder(const T& entry, const tvm::Array<T>& inputs = {});
  explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<PrimExpr>& inputs = {});
  explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<tir::Stmt>& inputs = {});
 
  ~SIBuilder();
 
  SIBuilder(const SIBuilder&) = delete;
  SIBuilder& operator=(const SIBuilder&) = delete;
 
  /*!
   * \brief create new source info based on span_buffer_.
   *
   * \return The span.
   */
  Span CreateSpan() const;
 
  /*!
   * \brief Recursively fill subgraphs exprs' span
   *
   * \param entry Entry expr for subgraph
   * \param inputs End exprs for subgraph
   */
  template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
  void RecursivelyFillSpan(const T& entry, const std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
 
  void RecursivelyFillSpan(const tir::Stmt& entry, const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
 
  void RecursivelyFillSpan(const tir::Stmt& entry, const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
 
 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
 
  std::unique_ptr<Impl> CreateImpl(const Span& span);
};

Starting from RecursivelyFillSpan, we describe how to fill a given span into newly generated expressions. Take RelayRecursivelyFill for the Relay types as an example: it inherits from ExprMutator to traverse the given expressions. If a visited expression is one of the inputs, it stops the traversal. Otherwise RecursivelyFillSpan dispatches on the corresponding type, sets up the span, and traverses deeper.

class RelayRecursivelyFill : public relay::ExprMutator {
 public:
  RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {})
      : span_(span), inputs_(inputs) {}
 
  void Fill(const relay::Expr& entry);
 
  relay::Expr VisitExpr(const relay::Expr& expr) final;
  relay::Expr VisitExpr_(const relay::CallNode* call_node) final;
  // other types...
 
 private:
  const Span& span_;
  const RelayExprSet& inputs_;
};

relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) {
  //...
  if (inputs_.find(expr) != inputs_.end()) {
    return expr;
  }
  //...
}

relay::Expr RelayRecursivelyFill::VisitExpr_(const relay::CallNode* call_node) {
  call_node->span = span_;
  return relay::ExprMutator::VisitExpr_(call_node);
}

On the other hand, the constructor of SIBuilder accepts an entry and a set of inputs to collect all of the source information. The core functionality for Relay is implemented by the class RelayCollapse, which inherits from ExprVisitor. The visitor function Collapse acts in a similar way to RecursivelyFill: it starts from the entry, puts the span of each expression into its array member variable, and continues the traversal until it hits the inputs. The collected spans are then produced by invoking the CreateSpan function on the SIBuilder instance.

class RelayCollapse : public relay::ExprVisitor {
 public:
  RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {}

  Span Collapse(const relay::Expr& entry);

  void VisitExpr(const relay::Expr& expr) final;

 private:
  tvm::Array<Span> spans_;
  const RelayExprSet& inputs_;
};

void RelayCollapse::VisitExpr(const relay::Expr& expr) {
  // ...
  if (expr->span.defined()) {
    spans_.push_back(expr->span);
  }
 
  if (inputs_.find(expr) != inputs_.end()) {
    visit_counter_.emplace(expr.get(), 1);
    return;
  }
  // ...
}

Span RelayCollapse::Collapse(const relay::Expr& entry) {
  VisitExpr(entry);
  return SequentialSpan(spans_);
}

Finally, SIBuilder can be disabled by setting ir.enable_si_builder in the PassContext config:

TVM_REGISTER_PASS_CONFIG_OPTION("ir.enable_si_builder", Bool);
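For example, users could turn SIBuilder off through the PassContext config (a sketch based on the option registered above; mod and params come from a frontend importer as in the earlier snippets):

import tvm
from tvm import relay

with tvm.transform.PassContext(opt_level=3, config={"ir.enable_si_builder": False}):
    lib = relay.build(mod, "llvm", params=params)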

Schedule/Stage visualization enhancement

  • Schedule Record:
    To inspect the series of Schedule transformations, new member variables are introduced to store the objects.

    // ${TVM}/include/tvm/te/schedule.h
    class ScheduleNode : public Object {
     public:
      ...
      /*!
      * \brief list of all schedules during primitives applied to stages.
      */
      Array<Schedule> schedule_record;
      /*!
      * \brief Flag to keep schedule record or not.
      */
      bool keep_schedule_record;
      ...
    };
    

    Every Stage inside a Schedule needs to know what the current Schedule is, and appends a snapshot of the Schedule after a primitive is applied.

    // ${TVM}/include/tvm/te/schedule.h
    class Stage : public ObjectRef {
     public:
      ...
      explicit Stage(Operation op, Schedule& sch);
      ...
      /*!
      * \brief Not functional currently.
      */
      TVM_DLL void EnterWithScope();
      /*!
      * \brief Store current schedule after primitive being applied.
      */
      TVM_DLL void ExitWithScope();
      ...
    };
    

    The “With” semantics are used here:

    // ${TVM}/src/te/schedule/schedule_lang.cc
    void Schedule::EnterWithScope() {}
    void Schedule::ExitWithScope() {
      ScheduleNode* sch_node = operator->();
      if (sch_node->keep_schedule_record) {
        sch_node->schedule_record.push_back(copy());
      }
    }
    

    All primitives can leverage the mechanism above to record the status of the Schedule; take the “parallel” primitive as an example:

    Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
    + With<Schedule> sch_scope(operator->()->attach_sch);
      SetAttrIterType(operator->(), var, kParallelized);
      return *this;
    }
    

    The effect is illustrated in the following snippet:

    def schedule_record_with_gemm():
      M, K, N = 1024, 1024, 1024
      k = te.reduce_axis((0, K), "k")
      A = te.placeholder((M, K), name="A")
      B = te.placeholder((K, N), name="B")
      C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")
      s = te.create_schedule(C.op)
      
      # currently there are no other applied primitives
      # size of schedule record is expected to be 1 (vanilla schedule)
      assert len(s.schedule_record) == 1
      
      # let's apply sequential optimization primitives
      block_size, factor = 32, 8
      # tile -> split + split + reorder
      mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], block_size, block_size)
      ko, ki = s[C].split(k, factor=factor)
      s[C].reorder(mo, ko, no, mi, ki, ni)
      s[C].vectorize(ni)
      s[C].parallel(mo)
      
      # the primitives inside schedule record are (primitive type and its store order):
      # vanilla(1), split(2), split(3), reorder(4), split(5), reorder(6), vectorize(7), parallel(8)
      assert len(s.schedule_record) == 8
    
  • Schedule Propagation:
    Investigating the TVM build flow (Relay to a target executable), the Schedule instance is stored in an attribute of the CallNode inside MakeLoweredCall and retrieved in the GraphExecutorCodegen process (i.e., schedules are finally kept in the corresponding graph nodes).

    [Figure: Callstack of build flow.png]

    Finally, a series of APIs will be created accordingly for users to access the Schedule instance from the Relay build module.
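For instance, iterating over the recorded snapshots could look like this (API names follow the proposed driving code shown earlier; this is a sketch only):

# assuming `lib` is the result of relay.build() as in the earlier snippet
sch = lib.build_module.get_node_schedule(17)
for i, snapshot in enumerate(sch.schedule_record):
    # one snapshot per applied primitive; index 0 is the vanilla schedule
    print(i, snapshot)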

Drawbacks

  • Processing extra debug information causes longer compilation time:
    • The cost of span filling when converting a deep learning model to Relay expressions.
    • The cost of handling span propagation when applying passes to the IR.
  • Storing extra debug information causes larger memory consumption at compile time:
    • The cost of keeping string information coming from the source names of the deep learning model.
    • The cost of saving snapshots of the schedule when primitives are applied.
  • Currently only the source_name member inside Span is used for the source-mapping mechanism; how should other members like line or col be leveraged?
  • More unit tests are needed to validate the proposed changes.

The collection of extra debug information can be controlled by environment variables to minimize the effect on performance.

Rationale and alternatives

  • The proposed changes are based on existing features, and we hope these enhancements make TVM more comprehensive:
    • Source-mapping mechanisms help users quickly identify the relationships between different IRs.
    • Existing tools like TEDD can be leveraged to visualize the effect of every schedule primitive.

Prior art

Frontend span filling

A basic set_span function has already been introduced to the TVM repo in the TensorFlow frontend. The new implementation we propose resolves the following problems:

  1. It supports TensorFlow 1 only.
  2. Only Call exprs are handled.
  3. It is not able to handle one-to-many conversions.

After investigation, we can support multiple frontends, which resolves problem 1. Basing set_span on ExprMutator lets us properly handle problems 2 and 3.

Pass source information builder

SequentialSpan extends the capability of Span so as to handle transformations with multiple sources. SIBuilder is a new helper class for developers when they are tagging spans in a pass.

Schedule/Stage visualization enhancement

This functionality extends the original design with some interface changes; the existing tool TEDD will also be modified slightly. With this new feature, users can better understand op implementations by visualizing the effect of primitives.

Unresolved questions

  • The standard for defining line numbers across different IRs.
    It is intuitive to do source mapping by matching the literal information from different frontends, which also provides more readability. However, the span information can gradually get messy when expressions are merged together during optimizations. The situation might be mitigated by mapping IRs via line numbers.
  • A proper way to generate unique source information for the ONNX frontend.
    The source information relies on the user’s annotations when making models. Some research remains to figure out a robust way of generating a unique identifier when the source information is missing.
  • Concatenating a suffix to the span to indicate the computation-output expression.
    Currently we have no strategy for highlighting expressions with special roles (e.g. input/output/parameter). It would be helpful for users to categorize expressions at first glance.

Future possibilities

  1. This RFC can be a good reference for other frontends to add their own span-filling mechanism.
  2. Extend the span propagation scope from TIR to backend IRs.
  3. It may be possible to have a text-format pretty print for the Schedule data structure rather than relying on TEDD only.
  4. Currently only the passes used by the regular build flow can handle span propagation. We hope every pass will be well supported in the future to provide better debuggability.

Upstream milestone

We plan to have the following PRs with corresponding test cases:

  • Span filling for following popular frontends
    • TFLite (PR)
    • PyTorch (PR)
    • TensorFlow (PR)
    • ONNX (PR)
  • Span propagation of Relay passes (passes within regular build flow will be covered)
    • Source information builder (PR)
    • Series of passes (PRs)
  • Schedule/Stage visualization enhancement
    • Schedule record and API changes (PR)
    • TEDD modification (PR)
  • Span propagation of TIR passes (WIP, passes within regular build flow will be covered)

cc @ulfhanebutte @dchickles: this is a bit related to tracing IR nodes through compilation, but perhaps not quite what you guys needed (full disclosure: I still need to finish reading it).


Thank you for the review!

Just a gentle ping: @areusch @ulfhanebutte @dchickles may we have your further comments on this topic?

GREAT JOB!!! I am interested in this. Do you hack Netron to do this interactive mapping?

Hi @FrozenGene,

Yes we did. We added a few lines to deal with the node name of NodeProto for frontends like PyTorch and ONNX in our internally maintained Netron.

Yet the core functionality of the interaction between the source model and the compiler is implemented in our in-house web application. Here is a GIF for your reference. :slight_smile:

[GIF: interaction]


What I could say is just AMAZING! Really appreciate your job! :+1: @chunit


What I could say is just AMAZING! Really appreciate your job! :+1: @chunit

We’re really glad to see you like it! :smiley:
Zack, Hao-Wei, and I all tried really hard to make this tool helpful for developing TVM.

Would you mind giving us some comments to make it better? Or is there anything unclear? We can try to explain it more precisely.

Thanks again!

I haven’t played with your app yet, but I have some thoughts about it. I think we could dump the output of this layer (from the input up to it); i.e., when we hover over one specific layer in Netron, we could dump the intermediate output of that layer, as you have already done the interactive mapping. We could also show this layer’s execution time on specific hardware (CPU/GPU or whatever), and show the TIR or generated source code like CUDA/LLVM IR/OpenCL, just as Compiler Explorer shows the compile result (assembly).

For this pre-rfc, I am happy to spend more time in reading it and maybe could give more comments.

Hi @chunit ,

Nice work! I noticed you talked about propagating the name of the layer in the frontend to the Relay op. Would it be possible to also propagate the names of the parameters of the layer, so that we can trace them back from the Relay IR?

Example:

Convolutional layer weights can have their layout altered, or be merged with other constants. It would be nice to be able to trace back from the Relay operator to the original weight name in the frontend layer.

Apologies for the delay in reply; I just needed to find some time to sit down and read the RFC all the way through. This is great work @chunit @haowhsu-quic and Zack! I’m supportive of moving this forwards.

In An enabling framework for int8 quantization, we discussed how to effectively track frontend layers throughout the compiler. It seems like you guys have taken the same approach we discussed there: leveraging the graph edges (e.g. tensors) as the “stable” part of the graph and labelling the Relay ops between them as belonging to the same frontend layer (e.g. RecursivelyFillSpan). Right now, this needs to be done per pass, but I wonder if we could get away with doing this once at the end of compilation if we also attach references to the frontend layer (or post-import variable) to each Relay Var.

It seems like by annotating Var we might be able to add this information.

One issue is that once we move outside of Relay (e.g. in AOT flow), it’s harder to fill span information back up through the compiler since the layer variables have changed. I’m curious if you guys tried to apply this to any TIR-based fusion?

Lastly, any idea how much additional memory this takes or performance impact?

Nice to see it ^^. Wondering if it is possible to have a website/GitHub repo to play around with?

Thank you very much for the suggestions for TVM Explorer! Here is some of the work we have done for the functionality you just mentioned. :slight_smile:

we could dump the output of this layer (from the input up to it)…

Our TVM Explorer does have a functionality, called Executor, to communicate with a device via RPC and obtain the inference result of a targeted Relay expression. It is not connected with Netron, but it is a good idea to think about how to connect them.

show the TIR or generated source code like CUDA/LLVM IR/OpenCL…

We aim to support this mapping in the GIF you saw, too. Currently the corresponding TIR/LLVM IR results can also be obtained in the Executor above. Yet we are still working on the span (source information) propagation after pass transformations, because span propagation between passes needs to be done per pass. So far we have done the span propagation for the necessary Relay passes in the build flow based on our infrastructure. As for the TIR part, we are still working on it. Here is a table showing which passes have spans filled, for your reference.

RelayPass | TIRPass | Not yet done TIR Pass
AlterOpLayout | LowerInitBlock | BF16Legalize
AutoSchedulerLayoutRewrite | LowerIntrin | CombineContextCall
CanonicalizeCast | MakePackedAPI | CompactBufferAllocation
CanonicalizeOps | MakeUnpackedAPI | ConvertBlocksToOpaque
CombineParallelBatchMatmul | NarrowDataType | FlattenBuffer
CombineParallelConv2D | PlanAndUpdateBufferAllocationLocation | HoistIfThenElse
CombineParallelDense | RemoveNoOp | InferFragment
DefuseOps | RewriteUnsafeSelect | InjectDoubleBuffer
DynamicToStatic | SplitHostDevice | InjectPrefetch
EliminateCommonSubexpr | | InjectVirtualThread
EtaExpand | | InstrumentBoundCheckers
FastMath | | LoopPartition
FoldConstant | | LowerCustomDatatypes
FoldScaleAxis | | LowerDeviceStorageAccessInfo
FuseOps | | LowerMatchBuffer
InferType | | LowerTVMBuiltin
Inline | | LowerThreadAllreduce
SplitArgs | | LowerWarpMemory
LabelOps | | MergeDynamicSharedMemoryAllocations
Legalize | | Simplify
RemoveUnusedFunctions | | StorageFlatten
SimplifyExpr | | StorageRewrite
SimplifyInference | | TextureFlatten
ToBasicBlockNormalForm | | ThreadSync
relay::qnn::transform::Legalize | | UnifyThreadBinding
 | | UnrollLoop
 | | VectorizeLoop
 | | VerifyMemory

For this pre-RFC, I am happy to spend more time reading it and maybe give more comments.

Take your time please, we will wait for it! :smiley:


Hi @fPecc,

Would it be possible to also propagate the names of the parameters of the layer, so that we can trace them back from the Relay IR?

Should be YES. :smiley:

Here is an example from a TFLite model for you. As you can see, the names of the weights of the conv2d ops and the values of the bias_add ops are attached to their inputs. (Note that we use “bind_params_by_name” in this example.)

Although it is possible, it requires some more investigation in each frontend and a small modification to the Var type printer. We need more time to confirm it.


No worries. :smiley: Thank you very much for helping us! If you don’t mind, I would like to submit more materials for you and ask some questions about the Var thing you just mentioned.

Right now, this needs to be done per pass

Yes, we attach spans per pass based on “SequentialSpan” and “SIBuilder”. It is a time-consuming task. Currently we have done the following passes; all of them are invoked during the build flow. We will try to complete the rest of the passes.

RelayPass | TIRPass | Not yet done TIR Pass
AlterOpLayout | LowerInitBlock | BF16Legalize
AutoSchedulerLayoutRewrite | LowerIntrin | CombineContextCall
CanonicalizeCast | MakePackedAPI | CompactBufferAllocation
CanonicalizeOps | MakeUnpackedAPI | ConvertBlocksToOpaque
CombineParallelBatchMatmul | NarrowDataType | FlattenBuffer
CombineParallelConv2D | PlanAndUpdateBufferAllocationLocation | HoistIfThenElse
CombineParallelDense | RemoveNoOp | InferFragment
DefuseOps | RewriteUnsafeSelect | InjectDoubleBuffer
DynamicToStatic | SplitHostDevice | InjectPrefetch
EliminateCommonSubexpr | | InjectVirtualThread
EtaExpand | | InstrumentBoundCheckers
FastMath | | LoopPartition
FoldConstant | | LowerCustomDatatypes
FoldScaleAxis | | LowerDeviceStorageAccessInfo
FuseOps | | LowerMatchBuffer
InferType | | LowerTVMBuiltin
Inline | | LowerThreadAllreduce
SplitArgs | | LowerWarpMemory
LabelOps | | MergeDynamicSharedMemoryAllocations
Legalize | | Simplify
RemoveUnusedFunctions | | StorageFlatten
SimplifyExpr | | StorageRewrite
SimplifyInference | | TextureFlatten
ToBasicBlockNormalForm | | ThreadSync
relay::qnn::transform::Legalize | | UnifyThreadBinding
 | | UnrollLoop
 | | VectorizeLoop
 | | VerifyMemory

I wonder if we could get away with doing this once at the end of compilation if we also attach references to the frontend layer (or post-import variable) to each Relay Var.

If it could be done at the end of compilation, it would be quite convenient! Sorry, I am not really following this. May I have your explanation again, please? For example, may I have an example of:

  1. What does attaching references to the frontend layer look like?
  2. What should be attached to a Relay Var?

It seems like by annotating Var we might be able to add this information.

About this part I would like some more explanation. Besides Vars and params, this problem also happens in one-to-many conversions. Here I would like to take the Pack op from TF as an example again. Currently we fill the layer name into the converted IR like this:

def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, …) /* strided_slice */;
    %2 = squeeze(%1) /* strided_slice */;
    # the Pack op conversion starts from here
    %3 = expand_dims(%2, axis=0) /* stack */;
    %4 = expand_dims(3, axis=0) /* stack */;
    %5 = expand_dims(3, axis=0) /* stack */;
    %6 = (%3, %4, %5) /* stack */;
    %7 = concatenate(%6) /* stack */;
}

And here is the result from the former patch:

def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, begin=[0], end=[1], strides=[1], axes=None) /* strided_slice_PART_0 */;
    %2 = squeeze(%1) /* strided_slice */;
    %3 = expand_dims(%2, axis=0) /* stack_PART_0 */;
    %4 = expand_dims(3, axis=0) /* stack_PART_1 */;
    %5 = expand_dims(3, axis=0) /* stack_PART_2 */;
    %6 = (%3, %4, %5) /* stack_PART_3 */;
    %7 = concatenate(%6) /* stack */;
}

In the former patch we could immediately identify the computation output of the Pack op, because only the output carried no suffix. We have now removed the suffix scheme because we noticed the “_PART_N” suffixes are really annoying and misleading after pass transformations.

The drawback of the current version is that we cannot tell which expression is the computation output, because they all look the same. Perhaps we could do something like the following example, but we are still seeking a better solution.

def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, …) /* strided_slice */;
    %2 = squeeze(%1) /* strided_slice */;
    # the Pack op conversion starts from here
    %3 = expand_dims(%2, axis=0) /* stack */;
    %4 = expand_dims(3, axis=0) /* stack */;
    %5 = expand_dims(3, axis=0) /* stack */;
    %6 = (%3, %4, %5) /* stack */;
    %7 = concatenate(%6) /* stack_OUTPUT */;
}

it’s harder to fill span information back up through the compiler since the layer variables have changed. I’m curious if you guys tried to apply this to any TIR-based fusion?

We are still working on the TIR passes, as shown in the list above. Besides, we haven’t done the propagation between Relay → TE or TIR, because that’s also a tough part we encounter. :frowning: Things are not too complicated in the Relay environment, but it becomes harder when we go down to lower IRs like TE and TIR. Currently we still rely on the layer name, yet we think that using the row and column numbers could be more robust and more indicative.

If we had a precise definition of the line-number information of an IRModule, we could at least have a better mapping relationship before and after “a pass”.

Lastly, any idea how much additional memory this takes or performance impact?

Yes. Take mobilenet_v1_2018_08_02 for example; here is the profiling result:

Runtime performance

function | Without span filling | With span filling | With span filling & schedule_record
relay.frontend.from_tflite() | 133174.0 us | 176468.0 us (↑32.51%) | 177774.0 us (↑33.49%)
relay.build() | 7480367.0 us | 7558526.0 us (↑1.045%) | 7580165.0 us (↑1.334%)

Memory usage

function | Without span filling | With span filling | With span filling & schedule_record
relay.frontend.from_tflite() | 26.105 MiB | 26.203 MiB (↑0.375%) | 26.211 MiB (↑0.406%)
relay.build() | 147.762 MiB | 148.148 MiB (↑0.261%) | 148.418 MiB (↑0.443%)

We also provide options to disable span filling and schedule recording if users don’t need them.


Hey Zack!

We are asking for legal approval. You know, it takes time, haha. I will post an update once I get the news.

Wow, such excellent work! I’ve always wanted an interactive debugging feature like this when playing around with TVM. You guys made it come true! Looking forward to the release :clap:


Cool, thanks for the explanations!

The Var thing I’m discussing here is not exactly a simple tweak to this proposal; it’s probably a significant enough lift that it would deserve its own RFC. So just to clarify: I’m not necessarily asking you to change your approach. However, I did want to raise this question to a) build more support for the idea, b) see if it is potentially easier to pursue than adding SIBuilder support to the remaining passes, and c) think through whether it’d be easier to maintain in the long run.

The basic idea is like so: consider your one-to-many example conversion. A common challenge we face in TVM is determining which Relay Exprs correspond to one another before and after a pass. To choose a concrete example, suppose we introduce a pass which outlines part of a function (suppose it outlines Pack from your previous example). Before executing the pass, suppose we start from your example:

Now suppose we run the outliner, and arrive at:

def @outlined_pack(%i1) {
  %0 = expand_dims(%i1, axis=0) /* stack */;
  %1 = expand_dims(3, axis=0) /* stack */;
  %2 = expand_dims(3, axis=0) /* stack */;
  %3 = (%0, %1, %2) /* stack */;
  %4 = concatenate(%3) /* stack */;
  %4
}

def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
    %0 = shape_of(%input, dtype="int32") /* Shape */;
    %1 = strided_slice(%0, …) /* strided_slice */;
    %2 = squeeze(%1) /* strided_slice */;
    # the Pack Op conversion start from here
    %3 = @outlined_pack(%2);
    %3
}

Now the question here is: after running the pass, does a new Relay var exist which contains %7? The answer is yes: it’s %3. In order to make this outline, an ExprMutator (for example) needed to capture the subgraph that contains %3 through %7, then replace it with a call to the new function and store the result in %3. This pass knows that %3 == %7, and (similarly to how Span information is filled here) when defining %3, it could include some type of backreference to %7. This could even just be included as a Map:

using VarMap = Map<Var,Var>;  // keys are originally-imported Var, values are the equivalent now inside f.
Function f = mod.GetFunction("main");
f->GetAttr<VarMap>("var_map");

This approach could be taken all the way back to the original import (e.g. or there could be an additional map from input framework layer to Relay var).

SIBuilder takes as input a set of Expr which bound the subgraph. Since most Relay programs are transformed in A-Normal form, the VarMap could substitute for these Expr. This won’t work for all optimizations, but I think for a decently large class of them, we could automatically apply SIBuilder by walking VarMap and applying Spans to the subgraphs with endpoints in VarMap. The advantage of this technique is that it could also be done with TIR with the same approach.

I think you’d need to assert that the Relay or TIR graph can be partitioned along VarMap for this to work, so I’m not saying it would work for all transforms. But I do think it would work for many. It’s also worth noting that this is a best-effort tracking scheme: it’s possible through e.g. operator fusion that some Vars could simply be eliminated. In these cases, the VarMap may not contain all Vars from the original model.

Thanks for providing this data! It seems reasonable as part of running with a debug option at least!


Thank you for this detailed explanation! We have digested the content and tried to apply this concept to an existing pass. There are still many implementation details we have not figured out, but the following is how we imagine the Var mechanism should work. Please kindly correct us if we misunderstand anything. :smiley:

Goal

Implement a pass to construct a graph. The graph is a tracing map that records the transformation before and after a pass.

What the map should look like

Personally I would prefer the keys to be the new equivalents now inside f, and the values to be the original Vars. That should be more convenient for us to trace back to the source. So it should be like:

Map<Var,Var>
// Keys are the equivalents now inside f.
// Values are the originally-imported Vars.

Because after a sequence of pass transformations, we end up with a final IRModule. Selecting a certain expression in the final IRModule[“main”], we can trace back to its source. If we used the originally-imported Var as the key, we would perhaps have to iterate through the whole map to find the resulting Var after the transformations.

How to invoke

Considering the function GetPassPrefix in “src/relay/backend/utils.cc”, we insert an OutLiner pass between passes:

//...
pass_seqs.push_back(transform::SimplifyInference());
pass_seqs.push_back(OutLiner);
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(OutLiner);
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(OutLiner);
//...

What the process looks like

Take the Relay pass SimplifyInference for example: it unpacks certain Calls, like the batch_norm op. The following image is part of the result after the transformation of the SimplifyInference pass in our Explorer.

It takes the batch_norm Call and its TupleGetItem as source exprs and unpacks them into a set of basic operations.

Now the following is the process once we introduce the OutLiner pass:

Back to the IR pretty print, we would start from IR[“main”] here:

def main(...) {
  %0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
  %1 = nn.batch_norm(%0,...) /* si=torch.batch_norm_8 */;
  %2 = %1.0 /* si=torch.batch_norm_8 */;
}

After the SimplifyInference the IR[“main”] becomes:

def main(...) {
  %0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %4 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
  %5 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %6 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %7 = multiply(%6, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %8 = add(%7, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %9 = multiply(%4, %5) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %10 = expand_dims(%8, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %11 = add(%9, %10) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}

Now it is time to invoke OutLiner. It generates another global function, outlined_bn_0.

def main(...) {
  %0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
  %1 = @outlined_bn_0(%0,...)
}

def outlined_bn_0(%i1...) {
  %0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %4 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %5 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %6 = multiply(%5, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %7 = add(%6, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %8 = multiply(%i1, %4) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %9 = expand_dims(%7, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %10 = add(%8, %9) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}

# Perhaps we would need the original main as a reference
def main_before_SimplifyInference_0(){
  #...
}

At the same time, we maintain the tracing map like this (the key and value should each be a Var, yet I am not quite sure how to express them in Var form).

# key: transformed result
# values: original things
map = {
    hash(outlined_bn_0): {%1-batch_norm, %2-%1.0}
}

Using the graph constructed from the tracing map, we should be able to trace an IR back to its very original form. Perhaps the functionality of OutLiner could be implemented based on StructuralEqual, but we haven’t come up with a good idea for this yet. Still, if this OutLiner is implementable, it will be really convenient. :smiley:

Questions

Here we have come up with some questions about this strategy:

  1. Which IRModule would be used once the OutLiner is invoked? It should be IR1 rather than IR2, right?
  • IR1
def main(...) {
  %0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
  %1 = @outlined_bn_0(%0,...)
}

def outlined_bn_0(%i1...) {
  %0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %4 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %5 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %6 = multiply(%5, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %7 = add(%6, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %8 = multiply(%i1, %4) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %9 = expand_dims(%7, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %10 = add(%8, %9) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
  • IR2
def main(...) {
  %0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %4 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
  %5 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %6 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %7 = multiply(%6, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %8 = add(%7, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %9 = multiply(%4, %5) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %10 = expand_dims(%8, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
  %11 = add(%9, %10) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
  2. If we choose IR1 and continue the transformations of the rest of the passes, it might end up in a nested form, and the readability would become terrible. Perhaps an unpacking pass for outlined_fn is required too, right?

  3. Still about the nested form: if we use the nested form like IR1, many pattern-matching routines may need to be rewritten, because now they need to look into the outlined_fn in the graph. The complexity of implementing a pass might increase.

Thank you for reading such a long post. It feels great that we can try to figure out a better way to maintain the source information. :smiley:


Sure thing. I think you broadly understand my proposal. Let me clarify some things:

It could be a pass, or it could be done some other way (e.g. by modifying the Expr constructor). The tracing map is the goal, though.

That seems reasonable, so long as the model is always in A-Normal Form. If it isn’t, then we may need Map<Expr,Expr> here. I think this was stated earlier; just reiterating.

This could also be handled by the PassManager, but yeah, that’s the right idea if we took a pass-based approach here. I’ll sketch some ideas below.

This is pretty close to my suggestion, but let me tweak it slightly. The goal here would be to map a Var in the final Relay or TIR representation to a Var that represents it in the original program (assume the original program is expressed in A-Normal Form, and suppose we allow trivial TupleGetItem Exprs in this map, so %0.2 is a valid value). After running this pass, the map here would then look like:

# key: transformed Var
# values: Expr representing the original value
# keys not present where no mapping exists
map = {
    %input: %input,
    %model.conv1.weight: %model.conv1.weight,
    ...  # same for the rest of the inputs (not as trivial if the keys were instead TIR Var)
    %4: %0,  # if I understood this transform properly, the reordering is due to A-Normal Form conversion after the rewrite, but in the final program %4 doesn't depend on %0, %1, %2, %3
    %1: %2  # or %1.0, if that was the only such representation of this.
}

Given this map, the Exprs that could be used with SIBuilder are then just the keys of the map.

I think you could then implement a fairly simple algorithm to apply SIBuilder (a rough sketch follows the list):

  1. Invert the variable map (swap keys and values).
  2. Step through the original program, and for each Relay Expr:
    1. Identify the inputs and outputs (this is akin to building a connectivity graph in the final program, but we sort of get it for free from the original)
    2. Look up those values in the resultant program using the Map
    3. Create SIBuilder with span equal to the Relay Expr. Run RecursivelyFillSpan(outputs, inputs).
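A rough Python-flavored sketch of that algorithm (the anf_bindings traversal helper and the Python-side SIBuilder binding are hypothetical; only the overall flow is intended):

def apply_spans_via_var_map(orig_func, final_func, var_map):
    # Invert the map: originally-imported Var -> equivalent Var in final_func.
    inv = {orig: new for new, orig in var_map.items()}
    # Step through the original (A-Normal Form) program binding by binding.
    for binding in anf_bindings(orig_func):  # hypothetical traversal helper
        outputs = [inv[v] for v in binding.outputs if v in inv]
        inputs = [inv[v] for v in binding.inputs if v in inv]
        if outputs and inputs:
            # Tag the matching subgraph of the final program with the span
            # of the original binding.
            si_builder = SIBuilder(binding.span)  # hypothetical Python binding
            si_builder.recursively_fill_span(outputs[0], inputs)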

I haven’t thought about this enough, but I think this could run into some limitations, maybe around loops and control flow, particularly if we apply the same approach to TIR. I’d need to think about it a bit further.

Building the map

As for how to build this map, here are some thoughts:

  1. Modify the Expr() constructor to take another arg, Expr orig_expr, and modify all passes to pass orig_expr.
  2. Change ExprMutator and kin to accept such a Map (or get it out of an IRModule attr). When Mutate_ returns a Node different from the one passed in, modify the map (a rough sketch follows this list).
  3. Attempt to derive this from an analysis pass, as you mentioned.
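For illustration, option 2 could look roughly like this on the Python side (a hypothetical sketch; the real change would live in the C++ mutators):

from tvm.relay.expr_functor import ExprMutator

class VarTrackingMutator(ExprMutator):
    """Hypothetical mutator that records new-Var -> original-Var links."""

    def __init__(self, var_map):
        super().__init__()
        self.var_map = var_map  # new Var -> originally-imported Var

    def visit_var(self, var):
        new_var = super().visit_var(var)
        if new_var is not var:
            # When a pass rewrites a Var, chain back to the original one.
            self.var_map[new_var] = self.var_map.get(var, var)
        return new_var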

I think #1 or #2 may not cover all cases here, and some passes may also need to be updated. The reason I’m raising this is that it seems like equivalent work to track relationships between Vars, and if it were possible to use that work to label Spans, we might be able to do this once. Finally, I’m thinking about how to apply SIBuilder to LowerTE, which is what generates TIR for Relay, and how to preserve that information when doing MetaSchedule-style transforms in a TensorIR world. It seems a bit more straightforward to propagate this Var relation than the Span info. Var tracking can also be useful in AOT for:

  • Identifying which TIR Vars represent which Relay Expr (e.g. implementing GraphExecutorDebug)
  • Profiling layers run in TIR, using those Vars as a hint for where a layer’s compute starts and stops.

Anyway, I am curious to hear your thoughts on whether you think we could leverage this for Span annotations. The work here is helpful for the project either way, so I think we could also merge this now and, if we can improve the maintainability via Var tracking, make that improvement as a follow-on.

cc’ing some other folks who have been thinking about this at octo: @anwang @AndrewZhaoLuo @mbaret @mehrdadh


This is awesome and super helpful work. Can’t wait to use it.
