[Relay][Pass] Add submodule extraction

We want to add a submodule extraction pass. In the future, this pass may enable post-fusion autotuning support, which should make tuning runs more accurate. With this pass, we'll also have better task de-duplication support for model zoos and better separation of layers for easier parallel autotuning. Finally, it would be nice to tune a single layer and an entire model in exactly the same way, and this pass helps achieve that.

Here is the currently proposed API. Both helpers invoke SimplifyInference and FuseOps on the given module before extracting:

  • extract_submodules(mod: IRModule) -> List[IRModule]
  • extract_hashed_submodules(mod: tvm.IRModule) -> Dict[int, tvm.IRModule], where int indicates a structural hash
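To make the difference between the two helpers concrete, here is a minimal pure-Python sketch. It is not TVM code: it assumes the fused functions are already available (i.e. it skips the SimplifyInference and FuseOps step), models a fused function as a tuple of op names and a module as a dict, and uses a SHA-256 digest as a hypothetical stand-in for a structural hash.

```python
import hashlib
from typing import Dict, List, Tuple

# Hypothetical toy IR: a fused function is a tuple of op names, and a
# "module" is a dict mapping a global name to a fused function.
Function = Tuple[str, ...]
Module = Dict[str, Function]


def toy_structural_hash(fn: Function) -> int:
    """Deterministic stand-in for a structural hash (not TVM's algorithm)."""
    digest = hashlib.sha256("|".join(fn).encode()).digest()
    return int.from_bytes(digest[:8], "little")


def extract_submodules(fused_fns: List[Function]) -> List[Module]:
    """Wrap every fused function in its own single-function module."""
    return [{"main": fn} for fn in fused_fns]


def extract_hashed_submodules(fused_fns: List[Function]) -> Dict[int, Module]:
    """Like extract_submodules, but structurally equal functions share one entry."""
    return {toy_structural_hash(fn): {"main": fn} for fn in fused_fns}


fused = [("conv2d", "add"), ("conv2d",), ("conv2d", "add")]
print(len(extract_submodules(fused)))         # 3
print(len(extract_hashed_submodules(fused)))  # 2
```

Keying by hash is what provides the de-duplication: the two identical {conv2d, add} layers collapse into a single tuning target.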

Should this be an analysis or a transform pass?

  • A0 analysis - we can treat this as a “read” of existing structures and construct a collection of IRModules.
  • A1 transform - most transform passes preserve semantic equivalence, which this pass does not. In this case, instead of constructing a collection of IRModules, we could construct a single IRModule containing all the functions.

How should we name this for clarity?

  • B0 extract_primitive_tasks
  • B1 extract_submodules
  • B2 extract_subgraphs

I plan to rewrite the pass in C++ following discussion. Thanks for your comments!


cc @zhiics @jroesch @mbaret, would love to get your thoughts.

I may be missing some context, but this pass is more like outlining functions to the module scope, right? May I ask why you need multiple IRModules? Do you want pipelined execution for them?

@zhiics Yes, some other applications are better task de-duplication support for model zoos, as well as better separation of layers for easier parallel autotuning. Also it would be nice to tune a single layer and an entire model in the exact same way, and this pass helps achieve that.

Edit: Added the above context to the original post.

I tend to prefer A0 and B1 in terms of pass placement, but what are your suggestions and/or reservations?

@anwang sorry, I still don't understand the motivation for this pass. After reading it, I know that it collects functions and creates a module for each function, but it seems like just a utility pass for convenience. If your goal is to auto-tune fused ops (please correct me if I misunderstood), then this pass is insufficient, because you will at least need to define a new tuning task and a way to generate a tuning space.

On the other hand, if this pass is the first step toward supporting primitive function auto-tuning, I would suggest having it output a list of functions instead of modules. The reason is that current AutoTVM creates a single-op module only when it wants to measure its latency. To maintain a unified measurement flow, it would be better to let the AutoTVM builder create the modules.

BTW, this pass also breaks the current IRModule -> IRModule convention for transformations. Is it possible to just lift the functions to the module scope and work from there?

Also, shouldn’t it be A1 since we vastly change the IRModule?

@comaniac Sorry for being unclear. Yep, your second point was my original intent: the pass is intended to be just the first step toward primitive function auto-tuning. To fully accomplish this, AutoTVM will need the ability to benchmark full primitive functions instead of only tunable ops, and that is future work after this initial pass.

@zhiics I agree that one IRModule -> many IRModules breaks the current idea of a transform pass. Instead, could we use @comaniac’s idea and keep this as an analysis pass that only gathers basic information, i.e. the pass itself only gathers all functions in a read-only manner?

So the new API I propose is an ExtractFunctions pass with the following helpers, which copy the input IRModule and invoke SimplifyInference and FuseOps only on the copy (let me know if copying isn't possible):

  • extract_functions(mod: IRModule) -> List[Function]
  • extract_hashed_functions(mod: IRModule) -> Dict[int, Function], where int indicates a structural hash
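A minimal sketch of the copy-then-extract behaviour, again in plain Python rather than TVM. The module layout, the in-place `toy_fuse_ops` rule (an `add` joins the group of the op before it) and the digest-based hash are all hypothetical stand-ins for FuseOps and a structural hash; the point is only that the caller's module is never modified.

```python
import copy
import hashlib
from typing import Dict, List, Tuple

Function = Tuple[str, ...]


def toy_fuse_ops(mod: dict) -> None:
    """Hypothetical in-place 'fusion': each 'add' joins the group before it."""
    groups: List[List[str]] = []
    for op in mod["main"]:
        if op == "add" and groups:
            groups[-1].append(op)
        else:
            groups.append([op])
    mod["main"] = groups  # mutates whatever module it is handed


def extract_functions(mod: dict) -> List[Function]:
    work = copy.deepcopy(mod)  # fuse a private copy, never the caller's module
    toy_fuse_ops(work)
    return [tuple(group) for group in work["main"]]


def extract_hashed_functions(mod: dict) -> Dict[int, Function]:
    def toy_hash(fn: Function) -> int:
        digest = hashlib.sha256("|".join(fn).encode()).digest()
        return int.from_bytes(digest[:8], "little")

    return {toy_hash(fn): fn for fn in extract_functions(mod)}


model = {"main": ["conv2d", "add", "conv2d", "conv2d", "add"]}
print(extract_functions(model))  # [('conv2d', 'add'), ('conv2d',), ('conv2d', 'add')]
print(model["main"])             # ['conv2d', 'add', 'conv2d', 'conv2d', 'add']
```

Returning functions rather than modules also leaves module construction to the consumer (e.g. the tuning builder), which matches the suggestion earlier in the thread.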

Does that look OK to you? It’s an extremely simple pass, but going forward I hope it will help enable tuning post-fusion rather than pre-fusion. Thanks for discussing, and let me know what other reservations you have! :slight_smile:

For the sake of thoroughness, I also want to raise the possibility of making ExtractFunctions a transformation pass, IRModule -> (IRModule of functions), e.g.:

body: [conv, add, conv, conv, add]
-> functions: [fn {conv, add}, fn {conv}, fn {conv, add}]
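The transform-pass alternative can be sketched the same way: instead of returning a collection, it returns one new module holding every fused function under its own global name. This is a plain-Python illustration under assumed conventions (hypothetical `fn_<i>` names, a module as a dict, and a toy rule in which an `add` fuses into a preceding `conv`); it reproduces the example above but is not how FuseOps decides what to fuse.

```python
from typing import Dict, List, Tuple


def to_function_module(body: List[str]) -> Dict[str, Tuple[str, ...]]:
    """IRModule -> (IRModule of functions), modelled with plain dicts."""
    groups: List[List[str]] = []
    for op in body:
        # Toy fusion rule: an elementwise 'add' joins a group ending in 'conv'.
        if op == "add" and groups and groups[-1][-1] == "conv":
            groups[-1].append(op)
        else:
            groups.append([op])
    # Hypothetical naming scheme: one global function per fused group.
    return {f"fn_{i}": tuple(g) for i, g in enumerate(groups)}


print(to_function_module(["conv", "add", "conv", "conv", "add"]))
# {'fn_0': ('conv', 'add'), 'fn_1': ('conv',), 'fn_2': ('conv', 'add')}
```

This keeps an IRModule -> IRModule signature, at the cost of the result no longer being semantically equivalent to the input.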

@anwang It makes sense to me. We can make it an analysis pass where you gather the functions. To keep the original module unchanged, I agree we need to return new copies of the functions. After that, you can mutate these functions and perform tuning on each of them.

Thanks for the clarification. Then I agree this pass is a necessary first step for primitive function tuning support.

btw, I’m super interested in this feature. It would be best if you could file an RFC with a complete proposal so that other people interested in this project can comment.

@anwang Could you elaborate more on the difference between autotuning single op vs autotuning the entire primitive function?

@haichen Yes, autotuning the entire primitive function should help enable tuning for fused ops.

Currently, AutoTVM tunes only the conv when given a {conv, add} workload, and fusion happens only after tuning, so the tuning result may not be ideal for the actual fused operator.

The desired behavior is for tuning to take fusion into account, so that the performance measurements reflect what actually runs. Does that answer your question?
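The difference can be illustrated with a toy model of how tuning tasks are keyed. In this hypothetical sketch, pre-fusion tuning keys a task only on the tunable op inside each group (here `conv2d`), so `{conv2d, add}` and `{conv2d}` become the same task, while post-fusion tuning keys on the whole fused function and keeps them apart. The set of tunable ops and the key format are assumptions, not AutoTVM's actual task representation, and real tasks also include shapes and dtypes, which are ignored here.

```python
from typing import List, Set, Tuple

TUNABLE_OPS = {"conv2d", "dense"}  # assumed set of ops that have tunable templates


def pre_fusion_tasks(groups: List[Tuple[str, ...]]) -> Set[Tuple[str, ...]]:
    """Key tasks on the tunable ops alone; fused neighbours are ignored."""
    return {(op,) for g in groups for op in g if op in TUNABLE_OPS}


def post_fusion_tasks(groups: List[Tuple[str, ...]]) -> Set[Tuple[str, ...]]:
    """Key tasks on the entire fused primitive function."""
    return set(groups)


fused = [("conv2d", "add"), ("conv2d",), ("conv2d", "add")]
print(pre_fusion_tasks(fused))   # {('conv2d',)}: one task, measured without the add
print(post_fusion_tasks(fused))  # two tasks: ('conv2d', 'add') and ('conv2d',)
```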

Okay, I now understand why you want to do this. I feel the proposal in this RFC is only useful as part of a larger-scope RFC. I’m not against this RFC, but I think it could benefit from more justification. For example, you could run some benchmarks to demonstrate the performance difference between tuning a single op and tuning a fused op.

I don’t think this pass is necessary for parallel autotuning. After we extract the AutoTVM tasks, we can already perform parallel tuning.

Beyond the AutoTVM-related tasks, I feel the pass itself is more useful from an information-extraction point of view.

One particular application could be extracting primitive sub-tasks for further analysis, such as counting the number of occurrences of each primitive task pattern.

So while it is not necessary to bake it in as a prerequisite for AutoTVM, making it an analysis pass makes sense, if we have the right interface (e.g. return a collection of functions, or just put them in an IRModule).
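As an example of that analysis use, once the fused primitive functions have been extracted, counting how often each pattern occurs takes only a few lines. This sketch assumes the functions are already represented as hashable tuples of op names, a hypothetical stand-in for keying real Relay functions by structural hash.

```python
from collections import Counter
from typing import List, Tuple


def count_patterns(functions: List[Tuple[str, ...]]) -> Counter:
    """Count occurrences of each primitive-function pattern."""
    return Counter(functions)


fns = [("conv2d", "add", "relu"), ("conv2d",), ("conv2d", "add", "relu"), ("dense", "add")]
for pattern, n in count_patterns(fns).most_common():
    print(n, "x", " -> ".join(pattern))
```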

Thanks for the discussion! For the sake of being thorough (though this is a very minor change), I’m concluding this thread with the final design decisions.

I have made this an analysis ModulePass that collects the fused primitive functions as the functions of a result IRModule.

Please add further suggestions to the PR itself.