How to match the pattern of a function in Relay?

def @main(%data: Tensor[(1, 112, 112, 32), float32]) -> Tensor[(1, 112, 112, 64), float32] {
  %3 = fn (%p0: Tensor[(1, 112, 112, 32), float32], %p1: Tensor[(3, 3, 32, 1), float32], %p2: Tensor[(1, 1, 1, 32), float32], %p3: Tensor[(1, 1, 1, 32), float32], Primitive=1) -> Tensor[(1, 112, 112, 32), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 112, 112, 32), float32] */;
    %1 = multiply(%0, %p2) /* ty=Tensor[(1, 112, 112, 32), float32] */;
    %2 = add(%1, %p3) /* ty=Tensor[(1, 112, 112, 32), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 112, 112, 32), float32] */
  };
  %4 = %3(%data, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 32, 1), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 1, 1, 32), float32] */, meta[relay.Constant][2] /* ty=Tensor[(1, 1, 1, 32), float32] */) /* ty=Tensor[(1, 112, 112, 32), float32] */;
  %8 = fn (%p01: Tensor[(1, 112, 112, 32), float32], %p11: Tensor[(1, 1, 32, 64), float32], %p21: Tensor[(1, 1, 1, 64), float32], %p31: Tensor[(1, 1, 1, 64), float32], Primitive=1) -> Tensor[(1, 112, 112, 64), float32] {
    %5 = nn.conv2d(%p01, %p11, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 112, 112, 64), float32] */;
    %6 = multiply(%5, %p21) /* ty=Tensor[(1, 112, 112, 64), float32] */;
    %7 = add(%6, %p31) /* ty=Tensor[(1, 112, 112, 64), float32] */;
    nn.relu(%7) /* ty=Tensor[(1, 112, 112, 64), float32] */
  };
  %8(%4, meta[relay.Constant][3] /* ty=Tensor[(1, 1, 32, 64), float32] */, meta[relay.Constant][4] /* ty=Tensor[(1, 1, 1, 64), float32] */, meta[relay.Constant][5] /* ty=Tensor[(1, 1, 1, 64), float32] */) /* ty=Tensor[(1, 112, 112, 64), float32] */
}

Hello, I’m just starting to learn how to use Relay’s pattern matching. I wonder if there’s a way to match a function pattern (together with its internal pattern, e.g. conv+mul+add+relu), like the two shown in the example above? Thanks in advance!

Check this out: https://tvm.apache.org/docs/langref/relay_pattern.html

Thanks! I had checked that out, but it doesn’t seem to show a way to match a function. In my case, conv+mul+add+relu is already wrapped into a function, so I failed to match them directly. The one example in the tutorial related to function matching uses a function attribute, but the functions I have above appear to have no attributes.

Any thoughts?

Hmm, I’m not quite sure whether the pattern matcher will go into Relay functions for matching. I’ll check it later, but maybe @mbrookhart could comment.

Meanwhile, maybe we can make a new FunctionPattern that matches function nodes.

It will definitely go inside the function to match patterns, but you’re right, we don’t have a FunctionPattern right now; we should probably add one.

This seems to be a function created by the FuseOps pass. Typically we’d do pattern rewriting/partitioning before that; maybe there’s a simpler way to get what you’re looking for?


+1 for this feature - I have some use-cases where it would be valuable to match composite functions.

I’m happy to add it, but it will be a couple of days before I can get to it. Anyone else interested in adding the node and some matching tests?

I have to agree with @mbrookhart’s point that pattern rewriting/partitioning comes before FuseOps. Could you provide more information as to why this feature would be useful?

I’d like to be able to rewrite a composite function into a single op. Obviously this can be done with an ExprMutator, but it’d be more convenient if the pattern language could do it.

Thanks @mbrookhart @mbaret for the valuable information! Here is a short summary of this issue.

  • The OP’s problem should be resolved by simply matching and partitioning patterns before FuseOps.
  • In addition, it would still be better to have a FunctionPattern that matches and rewrites Relay functions.

I’m interested in adding the node, but I may not have the bandwidth these days either. @mbrookhart, I could ping you when I get a chance to work on it to avoid duplicated work :slight_smile:

@comaniac sounds good, we’ll see who gets there first :slight_smile:

I also +1 this feature.

It seems that what you want to do is very similar to a previous post of mine: Relay pattern replacing with custom relay ops?

The biggest fear I have about introducing a new Relay operator (which I guess is what would eventually happen if we follow this path) is that it would have to work with all the Relay transformation passes. The example in the TVM documentation https://tvm.apache.org/docs/dev/relay_add_op.html seems too trivial and not sufficient to answer this question.

Any guidance you could give would be gladly received :slight_smile:

Thank you all for the discussion!

When a module is built, do the pattern matching and the MergeComposite pass currently happen before FuseOps or after? Is it possible to make this configurable? In some use cases (like mine), the matching and merging take the results of previous passes, e.g. SimplifyInference, as their input.

It should be configurable, but it depends on your use case. You could first figure out when FuseOps is invoked (and by which API) in your script, and then we can see whether it makes sense to run pattern matching and MergeComposite beforehand. Otherwise, you may still need FunctionPattern.

@comaniac OK I see. I think FunctionPattern would solve it all for me here. Glad to see the team has a plan on this!

@mbrookhart

I have tried using this to match

pattern = is_op('nn.conv2d')(wildcard(), wildcard())
pattern = is_op('multiply')(pattern, wildcard())
pattern = is_op('add')(pattern, wildcard())
tuple_get_item_node = TupleGetItemPattern(pattern, 0)
pattern = is_op('nn.relu')(tuple_get_item_node)

but it returns false. It looks like it doesn’t go inside the function. I also tried disabling FuseOps and inserting MergeComposite after Optimize() has finished, but it throws an error saying something like "add is detected and will be removed by future passes."

Just to see if there’s any luck here. Is there a way to get it to work with the current code base?

Why does your pattern have a tuple node? It seems to me that it tries to match

    %0 = nn.conv2d(%p0, %p1, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 112, 112, 32), float32] */;
    %1 = multiply(%0, %p2);
    %2 = add(%1, %p3);
    %3 = %2.0; /* This looks incorrect? */
    nn.relu(%3)

Yeah… You’re right. I copied it from the conv+bn+relu pattern and forgot to delete the line for the tuple. But removing it still doesn’t make it work.

The way the API is written, the matcher doesn’t traverse your input graph; it will return false unless the exact node you supply matches the pattern.

The partition and rewriting functions, however, will traverse the relay expression, go into functions, and find the match inside your function. If you partition instead of matching, you should see the extraction of your pattern into a function.

Thanks! I’ll try that after I fix up some of the bugs I’m having.

@mbrookhart So I tried the way you suggested and I’m able to rewrite the pattern inside a function. I wonder if it’s also possible to partition and rewrite a pattern across multiple functions? I suspect that would need the support of the potential FunctionPattern.