Since the pattern language has been merged, we are planning to rewrite the MergeComposite pass for BYOC with it.
Brief Introduction to Merge Composite Pass
BYOC is designed to help 3rd party codegens and accelerators integrate into TVM. Since a 3rd party accelerator may have a specialized and powerful instruction that handles multiple operators at once (e.g., conv2d+add+relu), the first pass in BYOC fuses those operators into a separate function and annotates the function with the instruction name. As a result, the external codegen can easily replace the entire function with a single instruction.
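To make this flow concrete, here is a toy sketch in Python. It is not TVM code: it assumes the graph is a linear chain of operator names, represents a composite function as a plain dict, and the helpers `merge_composite` and `external_codegen` are hypothetical. It shows the two steps described above: fusing a matched operator sequence into a function annotated with the instruction name, and then letting the external codegen emit one instruction per composite function.

```python
def merge_composite(graph, patterns):
    """Toy version of the first BYOC pass (hypothetical, not TVM's API).

    Replaces each contiguous run of ops that matches a pattern with a
    'function' dict whose Composite attribute is the instruction name.
    """
    out, i = [], 0
    while i < len(graph):
        for inst, pat in patterns.items():
            if graph[i:i + len(pat)] == pat:
                out.append({"body": list(pat), "attrs": {"Composite": inst}})
                i += len(pat)
                break
        else:  # no pattern matched at position i: keep the op as-is
            out.append(graph[i])
            i += 1
    return out


def external_codegen(graph):
    """Toy codegen: one instruction per composite function or plain op."""
    return [node["attrs"]["Composite"] if isinstance(node, dict) else node
            for node in graph]


graph = ["conv2d", "add", "relu", "softmax"]
fused = merge_composite(graph, {"my-inst": ["conv2d", "add", "relu"]})
print(external_codegen(fused))  # ['my-inst', 'softmax']
```

A real Relay graph is a DAG rather than a list, which is exactly why a proper pattern language is needed for the matching step.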
Reasons to Use Pattern Language
The current implementation of the MergeComposite pass accepts a user-specified map such as {'my-inst': graph}, where graph is a small Relay program used as a pattern. More detailed use cases can be found in the unit test:
However, using a Relay program as a pattern imposes many limitations. Essentially all of the motivations given in the pattern language RFC ([RFC][Relay] Program Matching for Relay Pt 1: A pattern language) apply here as well.
Proposal and POC
Accordingly, we propose to use the pattern language for MergeComposite. It's simple, robust, and more general. For example, I've written a POC of MergeComposite in the pattern language. As can be seen, it takes fewer than 100 lines to achieve the same functionality.
The above POC includes two solutions.
Pattern.Partition
This is a built-in feature of the pattern language that partitions the matched subgraph into a separate function. However, it doesn't allow users to specify function attributes, so we cannot add Composite = inst.
PatternCallback
Another solution uses callback functions to manually create composite functions. The problem with this solution is that we need an extra visit to mutate the Relay graph in order to create the function.
Discussion
While we prefer the first solution, which uses pattern.partition(), we need to figure out how to add the composite attribute.
S.1: Enhance Partition
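A straightforward approach is to enhance pattern.partition so that it accepts additional configurations, such as function attributes:
```python
for pattern, inst in patterns:
    # Proposed: pass the attributes for the generated function directly.
    out = pattern.partition(out, {'Composite': inst})
```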
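The following toy sketch (not the TVM API) illustrates what S.1 would give us. The `partition` helper is hypothetical, works on a linear list of operator names, and takes an optional `attrs` argument that stands in for the proposed second argument of `pattern.partition`; the driver loop mirrors the snippet above.

```python
def partition(graph, pattern, attrs=None):
    """Toy partition (hypothetical): wrap every contiguous match of
    `pattern` in a function dict, attaching `attrs` when given."""
    out, i, n = [], 0, len(pattern)
    while i < len(graph):
        if graph[i:i + n] == pattern:
            out.append({"body": list(pattern), "attrs": dict(attrs or {})})
            i += n
        else:
            out.append(graph[i])
            i += 1
    return out


patterns = [(["conv2d", "add", "relu"], "conv2d_add_relu"),
            (["dense", "relu"], "dense_relu")]
out = ["conv2d", "add", "relu", "dense", "relu"]
for pattern, inst in patterns:
    # One call both partitions and annotates, so no extra pass is needed.
    out = partition(out, pattern, {"Composite": inst})

names = [n["attrs"]["Composite"] if isinstance(n, dict) else n for n in out]
print(names)  # ['conv2d_add_relu', 'dense_relu']
```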
S.2: Post-Processing
If we do not want to change partition, we could record the existing functions in a set before each partition, use it to identify the newly generated functions afterward, and add the attribute to them:
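```python
for pattern, inst in patterns:
    # Record the functions that exist before this partition.
    curr_funcs = get_func_as_set(out)
    out = pattern.partition(out)
    for func in out.functions:
        # Only the newly generated functions get the attribute.
        if func not in curr_funcs:
            # Note: with_attr returns a new function rather than mutating
            # func, so the result still has to be substituted back into out.
            func.with_attr('Composite', inst)
```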
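For comparison, here is a toy sketch of S.2 under the same simplifying assumptions (a linear list of op names, functions as dicts). `partition`, `with_attr` and `merge_composite` are hypothetical helpers, not TVM APIs. Existing functions are tracked by object identity, and because `with_attr` returns a copy instead of mutating, the graph is rebuilt with the annotated functions.

```python
def partition(graph, pattern):
    """Toy partition without attribute support (hypothetical)."""
    out, i, n = [], 0, len(pattern)
    while i < len(graph):
        if graph[i:i + n] == pattern:
            out.append({"body": list(pattern), "attrs": {}})
            i += n
        else:
            out.append(graph[i])
            i += 1
    return out


def with_attr(func, key, value):
    """Return a copy of `func` with one extra attribute (no mutation)."""
    return {"body": func["body"], "attrs": {**func["attrs"], key: value}}


def merge_composite(graph, patterns):
    out = graph
    for pattern, inst in patterns:
        # Remember which functions existed before this partition, by identity.
        curr_funcs = {id(n) for n in out if isinstance(n, dict)}
        out = partition(out, pattern)
        # Annotate only the new functions; rebuild since with_attr copies.
        out = [with_attr(n, "Composite", inst)
               if isinstance(n, dict) and id(n) not in curr_funcs else n
               for n in out]
    return out


patterns = [(["conv2d", "add", "relu"], "conv2d_add_relu"),
            (["dense", "relu"], "dense_relu")]
result = merge_composite(["conv2d", "add", "relu", "dense", "relu"], patterns)
print([n["attrs"]["Composite"] for n in result])
# ['conv2d_add_relu', 'dense_relu']
```

The identity set and the rebuild pass are the extra bookkeeping S.2 needs compared with S.1.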
Any other thoughts and comments are welcome.