Problem
As discussed here, there are many implementations of the exact same core computation. To avoid writing a ton of patterns to cover all of these implementations, it would be nice to have a way to “always match” certain operators in the expression being matched, even if they aren’t present in the user-defined pattern.
In other words, I propose a way to match the core computation without having to match all of the extraneous data-mutation operators like reshape, transpose, cast, etc.
Example Use-Case
I am trying to match a transformer, using the HuggingFace Transformer exported to ONNX. From the ONNX perspective, the start of the transformer looks like MatMul -> Add. However, after importing into TVM, the ONNX frontend inserts a number of data-mutation operators. These account for broadcasting and for the fact that TVM performs matrix multiplication as (m,k) x (n,k), i.e. with the second operand transposed, whereas ONNX performs it as (m,k) x (k,n). As a result, the Relay expression becomes Reshape -> Reshape -> Transpose -> MatMul -> Reshape -> Add.
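To make the layout mismatch concrete, here is a pure-Python sketch (no TVM involved) showing that an ONNX-style MatMul -> Add on a 3-D input gives the same result as flattening the batch dimensions, transposing the weight to (n,k), running a dense-style (m,k) x (n,k) product and reshaping back. The helper names are hypothetical, and the exact op sequence the ONNX frontend emits may differ; this only illustrates why the extra Reshape/Transpose nodes show up around the core computation.

```python
from typing import List

Matrix = List[List[int]]

def transpose(m: Matrix) -> Matrix:
    # (k, n) -> (n, k)
    return [list(col) for col in zip(*m)]

def matmul_onnx(a: Matrix, b: Matrix) -> Matrix:
    # ONNX MatMul layout: (m, k) x (k, n) -> (m, n)
    return [[sum(row[p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def dense_tvm(a: Matrix, b: Matrix) -> Matrix:
    # Dense-style layout: (m, k) x (n, k) -> (m, n), i.e. a times b transposed
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]

def add_bias(m: Matrix, bias: List[int]) -> Matrix:
    # Broadcast a length-n bias over every row of an (m, n) matrix
    return [[v + c for v, c in zip(row, bias)] for row in m]

def flatten_batch(x: List[Matrix]) -> Matrix:
    # Reshape (batch, seq, k) -> (batch * seq, k)
    return [row for mat in x for row in mat]

def unflatten_batch(m: Matrix, batch: int) -> List[Matrix]:
    # Reshape (batch * seq, n) -> (batch, seq, n)
    seq = len(m) // batch
    return [m[i * seq:(i + 1) * seq] for i in range(batch)]

# Toy inputs: x is (batch=2, seq=2, k=3), w is (k=3, n=2), bias is (n=2,)
x = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [1, 0, 1]]]
w = [[1, 0], [0, 1], [1, 1]]
bias = [10, 20]

# ONNX view: MatMul -> Add, applied to each batch element
onnx_out = [add_bias(matmul_onnx(mat, w), bias) for mat in x]

# Frontend-style view: Reshape -> Transpose -> dense -> Reshape -> Add
flat = flatten_batch(x)
w_t = transpose(w)
relay_out = [add_bias(mat, bias)
             for mat in unflatten_batch(dense_tvm(flat, w_t), len(x))]

print(onnx_out == relay_out)  # True
```

Both paths compute the same values; only the surrounding data-layout operators differ, which is exactly what a pattern for MatMul -> Add would have to see through.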
Different frontends may handle the reshapes and transposes differently, but they all perform the same core computation, MatMul -> Add. To avoid writing a ton of patterns, I would like an option to always treat these reshape and transpose operators as a match and simply skip over them. The core MatMul -> Add pattern would then always match, and any operators in between would be merged into the composite function.
Solution
I am only just starting to look at the implementation of MergeComposite, so I would look to @mbaret and @comaniac for suggestions. However, one solution could be the following:
The MergeComposite pass would take an extra parameter: a list of Relay ops to “always match”. The ExtractPattern function would continue walking the pattern “in lockstep” with the root (the expression being matched). When the root and pattern differ and the root’s op is in the “always match” list, the root node advances while the pattern node stays in place. ExtractPattern would return an empty expression only if the root and pattern nodes differ and the root node’s op is not in the “always match” list.
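A minimal, self-contained Python sketch of the proposed lockstep walk, using a toy expression tree rather than real Relay classes. `Var`, `Call`, `extract_pattern`, `ALWAYS_MATCH` and the `always_match` argument are hypothetical stand-ins, not TVM's actual `ExtractPattern` signature; the function returns a bool instead of building the composite body, and it records the skipped ops so you can see what would be merged.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Union

@dataclass
class Var:
    # Leaf node; in a pattern it acts as a wildcard input of the composite
    name: str

@dataclass
class Call:
    # Operator application, e.g. Call("matmul", [lhs, rhs])
    op: str
    args: List["Expr"]

Expr = Union[Var, Call]

def extract_pattern(pattern: Expr, root: Expr,
                    always_match: FrozenSet[str], absorbed: List[str]) -> bool:
    """Walk pattern and root in lockstep; return True on a match.

    Ops from `always_match` that appear in the root but not in the pattern
    are skipped (root advances, pattern stays) and recorded in `absorbed`.
    """
    if isinstance(pattern, Var):
        # Pattern variables match any sub-expression (composite inputs)
        return True
    if not isinstance(root, Call):
        return False
    if root.op == pattern.op and len(root.args) == len(pattern.args):
        # Same op: move both pointers to the corresponding arguments
        return all(extract_pattern(p, r, always_match, absorbed)
                   for p, r in zip(pattern.args, root.args))
    if root.op in always_match and len(root.args) == 1:
        # Mismatch on an "always match" op: skip it, keep the pattern node
        absorbed.append(root.op)
        return extract_pattern(pattern, root.args[0], always_match, absorbed)
    # Mismatch on any other op: this is where ExtractPattern would give up
    return False

# Usage: the core MatMul -> Add pattern against a graph with extra layout ops
ALWAYS_MATCH = frozenset({"reshape", "transpose", "cast"})
pattern = Call("add", [Call("matmul", [Var("a"), Var("b")]), Var("c")])
expr = Call("add", [
    Call("reshape", [Call("matmul", [
        Call("reshape", [Var("x")]),
        Call("transpose", [Call("reshape", [Var("w")])]),
    ])]),
    Var("bias"),
])

absorbed: List[str] = []
print(extract_pattern(pattern, expr, ALWAYS_MATCH, absorbed), absorbed)
# True ['reshape']
```

In this sketch a pattern variable matches any sub-expression, so the reshapes and transposes feeding the inputs (on x and w) stay outside the match as composite inputs; only the always-match op sitting between two pattern ops is absorbed. Skipping only single-argument ops keeps the walk unambiguous about which argument the pattern should follow.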
Please comment with other ideas and implementation suggestions.