TuplePattern for any number of inputs to the tuple

Hello,

Is it possible to define a pattern for a tuple with any number of inputs? When using the tvm.relay.dataflow_pattern.is_tuple() function, you define a pattern with a fixed number of fields. My main goal is to define a pattern where a tuple of any size is fed into a concat function.

You mean, the number of elements is fixed when you actually match them, right? You are looking for “one or more” pattern (regex equivalent of “+”)? I don’t think we have that yet.

I also have a similar matching problem, where I want to match against an unrolled for loop: I want to match the loop body a certain number of times. In my case, though, I know the unroll factor when I actually instantiate the pattern, so it is easy to construct a big pattern consisting of the repeated body pattern on the fly.
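When the repeat count is known at pattern-construction time, the big pattern can be built by composing the body pattern with itself. Below is a minimal pure-Python sketch that uses a toy nested-tuple encoding instead of TVM's `DFPattern` classes; `build_unrolled`, `matches`, `body`, and the `"*"` wildcard marker are all hypothetical. In TVM itself, `body` would return a real pattern such as `is_op("add")(prev, wildcard())`.

```python
def build_unrolled(body, factor, seed="*"):
    """Compose a loop-body pattern `factor` times: body(body(...body(seed)))."""
    pattern = seed
    for _ in range(factor):
        pattern = body(pattern)
    return pattern


def matches(pat, expr):
    """Toy structural matcher: "*" matches anything; (op, *children) tuples
    must agree on op and arity and then match child by child."""
    if pat == "*":
        return True
    if isinstance(pat, tuple) and isinstance(expr, tuple) and len(pat) == len(expr):
        return pat[0] == expr[0] and all(matches(p, e) for p, e in zip(pat[1:], expr[1:]))
    return pat == expr


def body(prev):
    # Toy loop body: an "add" whose first input is the previous iteration's result.
    return ("add", prev, "*")


pattern = build_unrolled(body, factor=3)
print(pattern)  # ('add', ('add', ('add', '*', '*'), '*'), '*')

unrolled = ("add", ("add", ("add", "x", "c0"), "c1"), "c2")
print(matches(pattern, unrolled))  # True
```

The pattern grows linearly with the unroll factor, which is fine when the factor is small and known up front.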

cc @mbrookhart

No, we don’t have this right now.

Multi-output patterns and recursive patterns have been on my (very low priority) backlog for a while, mostly because no-one has complained about them. It doesn’t sound like you need quite that full functionality, though.

If all you need is “this is a tuple with some number of inputs, and I don’t care how many inputs there are or what they are”, it would be pretty easy to create a TuplePattern with None fields instead of an Array of fields, and then add some handling to the passes to ignore the pattern fields if they aren’t defined. I can do that today, if that would work for you.
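The proposal can be modeled in a few lines: a tuple pattern whose fields are undefined matches a tuple of any arity, while a defined list still matches position by position. This is a toy pure-Python sketch, not TVM's `TuplePattern`; `ToyTuplePattern`, `match_tuple`, and the `"*"` wildcard marker are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ToyTuplePattern:
    # None: "some tuple, any number of fields, don't care what they are".
    # A sequence: exactly these field patterns, matched position by position.
    fields: Optional[Sequence[str]] = None


def match_tuple(pat: ToyTuplePattern, fields: tuple) -> bool:
    if pat.fields is None:
        return True  # undefined fields: skip the per-field checks entirely
    if len(pat.fields) != len(fields):
        return False  # defined fields: arity must agree
    return all(p == "*" or p == f for p, f in zip(pat.fields, fields))


print(match_tuple(ToyTuplePattern(), ("in_0", "in_1", "in_2")))            # True
print(match_tuple(ToyTuplePattern(["*", "*"]), ("in_0", "in_1", "in_2")))  # False
```

Note that in this toy the undefined-fields pattern also matches the empty tuple, since it ignores the fields completely.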

I think it would be pretty straightforward to define that pattern as concat(wildcard()), though: the only thing we can concatenate is a tuple, so if we’re throwing away the information about the fields, this is equivalent to the None-fields case.

That would be great. If it helps, my current workaround for capturing tuples of up to 20 inputs is:

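```python
from tvm.relay.dataflow_pattern import is_tuple, wildcard

def _generate_tuple_pattern():
    """Generates a pattern for tuples that can take in 1 to 20 inputs"""
    pattern = is_tuple([wildcard()])
    for i in range(2, 21):  # upper bound is exclusive, so 21 gives arities up to 20
        pattern = pattern | is_tuple([wildcard() for _ in range(i)])
    return pattern
```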

I am using these patterns with relay.transform.MergeComposite. I do not think concat(wildcard()) pulls the input tuple into the pattern; the tuple stays outside the composite function. This is the result when executing the code below:

```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op

# Build a function that concatenates two (1, 4, 4, 3) inputs along axis 1.
data = []
for i in range(2):
    data.append(relay.var("in_{}".format(i), shape=(1, 4, 4, 3), dtype="float32"))
z = relay.concatenate(data, axis=1)
concat = relay.Function(relay.analysis.free_vars(z), z)
mod = tvm.IRModule.from_expr(concat)

# concatenate applied to anything: the input tuple is matched by the wildcard.
concat_pattern = is_op("concatenate")(wildcard())
mod = relay.transform.MergeComposite([("concat_pattern", concat_pattern)])(mod)
print(mod)
```

Output:

```
def @main(%in_0: Tensor[(1, 4, 4, 3), float32], %in_1: Tensor[(1, 4, 4, 3), float32]) -> Tensor[(1, 8, 4, 3), float32] {
  %0 = (%in_0, %in_1);
  %1 = fn (%FunctionVar_0_0: (Tensor[(1, 4, 4, 3), float32], Tensor[(1, 4, 4, 3), float32]), PartitionedFromPattern="concatenate_", Composite="concat_pattern", Primitive=1) -> Tensor[(1, 8, 4, 3), float32] {
    concatenate(%FunctionVar_0_0, axis=1) /* ty=Tensor[(1, 8, 4, 3), float32] */
  };
  %1(%0) /* ty=Tensor[(1, 8, 4, 3), float32] */
}
```

Ah, I see, the wildcard becomes a variable for partitioning. That makes sense.

Give me a few hours, I’ll try to get this done around meetings this morning.


Sorry for the delay, meetings ate up my morning. I’ve implemented my initial idea, and while it passes the matching test, it fails the partitioning test, just like is_op("concatenate")(wildcard()) does.

This is happening because in partitioning, I assume that patterns without inputs are themselves inputs to the subgraph, so I treat them as inputs to the partitioned function.
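The partitioning assumption can be illustrated with a toy walk over a pattern and an expression in lock-step: every leaf pattern with no inputs (here the wildcard `"*"`) marks a sub-expression that becomes a parameter of the partitioned function. This is a hypothetical pure-Python model, not TVM code (`collect_inputs` and the `(op, *children)` tuple encoding are assumptions), but it reproduces why `is_op("concatenate")(wildcard())` leaves the tuple outside the composite function in the output above.

```python
def collect_inputs(pat, expr, inputs):
    """Walk pattern and expression together; each wildcard leaf ("*") turns the
    matched sub-expression into an input of the partitioned function.

    Patterns are either "*" or (op, *child_patterns) tuples."""
    if pat == "*":
        inputs.append(expr)
        return inputs
    if not isinstance(expr, tuple) or pat[0] != expr[0] or len(pat) != len(expr):
        raise ValueError(f"pattern {pat!r} does not match {expr!r}")
    for p, e in zip(pat[1:], expr[1:]):
        collect_inputs(p, e, inputs)
    return inputs


expr = ("concatenate", ("tuple", "in_0", "in_1"))

# concat(wildcard()): the whole tuple is a leaf match, so it stays outside.
print(collect_inputs(("concatenate", "*"), expr, []))
# [('tuple', 'in_0', 'in_1')]

# concat(is_tuple([wildcard(), wildcard()])): the tuple is pulled inside.
print(collect_inputs(("concatenate", ("tuple", "*", "*")), expr, []))
# ['in_0', 'in_1']
```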

Changing that assumption to bring those matches into the partitioned function is going to break a lot of other things, so I don’t think that’s feasible in the short term.

Your workaround might be the best option in the short term. I’d need to think about how to support a regex-style “+” in a tuple; I’m not sure what the knock-on effects would be.


That’s all right. That input assumption is good as is. Maybe the solution would look like is_op("concatenate")(is_tuple(wildcard())), where a wildcard() input to the tuple captures any type of input into the tuple, or maybe is_tuple(wildcard()) matches any tuple operation.

Hmm, I will keep stewing. That’s an interesting idea, but I think it’s going to run into type issues on the C++ side.

This works, but I don’t love it. Thoughts?


It looks fine. Maybe instead of None we just leave the field blank?

That is, is_tuple() instead of is_tuple(None).

To me, None seems to imply an empty tuple. Leaving it blank would also be consistent with is_constant() and wildcard(), which take no arguments.

So, back to thinking about this this morning. @jeff815 @slyubomirsky

We basically have three places we might want to match any number of inputs: Call, Function, and Tuple.

@masahi recommended a Many pattern, but since we aren’t trying to match the IR itself, but rather a variable-size container in the IR, introducing this as a Pattern breaks the lock-step recursive matching of the pattern matcher. We went on to discover that what he actually needs is recursive matching, which is a bigger problem.
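The lock-step problem shows up even in a toy field matcher. Ordinary patterns pair pattern *i* with field *i*, so arities must agree. A hypothetical `Many` element (regex “+”) cannot be paired with a single field, because it stands for a slice of the container, so the matcher needs a separate code path for it. This is a minimal pure-Python sketch; `Many`, `match_leaf`, and `match_fields` are illustrative names, not TVM API.

```python
from dataclasses import dataclass


@dataclass
class Many:
    """Hypothetical "one or more of `elem`" marker, allowed only as the last field."""
    elem: str


def match_leaf(pat: str, field: str) -> bool:
    return pat == "*" or pat == field


def match_fields(pats: list, fields: list) -> bool:
    if pats and isinstance(pats[-1], Many):
        # Special case: the fixed prefix is matched in lock-step, and the
        # trailing Many must absorb a non-empty tail of arbitrary length.
        head, many = pats[:-1], pats[-1]
        tail = fields[len(head):]
        return (len(tail) >= 1
                and match_fields(head, fields[:len(head)])
                and all(match_leaf(many.elem, f) for f in tail))
    # Ordinary lock-step matching: pattern i against field i.
    return len(pats) == len(fields) and all(
        match_leaf(p, f) for p, f in zip(pats, fields))


print(match_fields(["*", "*"], ["a", "b", "c"]))        # False: arity mismatch
print(match_fields(["a", Many("*")], ["a", "b", "c"]))  # True
print(match_fields(["a", Many("*")], ["a"]))            # False: "+" needs one
```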

I can do what I did in that PR for Call and Function; it just requires copy/pasting this section of the code for everything we’d want to support.

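```cpp
// Tuple pattern with undefined fields: instead of treating the matched tuple
// as a single input, make each field of every matched tuple an input.
auto tuple = node->ref_.as<TuplePatternNode>();
if (tuple && !tuple->fields.defined()) {
  if (node_map.count(node->ref_)) {
    auto matches = node_map[node->ref_];
    for (auto match : matches) {
      for (auto input : match.as<TupleNode>()->fields) {
        make_input(input);
      }
    }
  }
}
```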

which is an ugly design, but doable.
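In Python terms, the C++ special case above amounts to this: when the matched pattern is a tuple with undefined fields, each field of the matched tuple (rather than the tuple itself) becomes an input, so the tuple node is pulled into the partitioned function. This is a toy sketch using a hypothetical `(op, *children)` tuple encoding of the IR; `ANY_TUPLE` stands in for a tuple pattern with undefined fields, and `collect_inputs` is illustrative, not TVM code.

```python
ANY_TUPLE = ("tuple", None)  # stand-in for a TuplePattern with undefined fields


def collect_inputs(pat, expr, inputs):
    """Collect the inputs of the function partitioned out of a match.

    Patterns are "*", ANY_TUPLE, or (op, *child_patterns) tuples."""
    if pat == "*":
        inputs.append(expr)  # wildcard leaf: the whole sub-expression is an input
    elif pat == ANY_TUPLE:
        if not (isinstance(expr, tuple) and expr[0] == "tuple"):
            raise ValueError(f"{expr!r} is not a tuple")
        inputs.extend(expr[1:])  # like make_input(input) for each tuple field
    elif isinstance(expr, tuple) and pat[0] == expr[0] and len(pat) == len(expr):
        for p, e in zip(pat[1:], expr[1:]):
            collect_inputs(p, e, inputs)
    else:
        raise ValueError(f"pattern {pat!r} does not match {expr!r}")
    return inputs


expr = ("concatenate", ("tuple", "in_0", "in_1", "in_2"))
print(collect_inputs(("concatenate", ANY_TUPLE), expr, []))
# ['in_0', 'in_1', 'in_2']: the tuple itself now lives inside the composite
```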

The other option, I think, is to take this kind of utility, which you both seem to be using:

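```python
from tvm.relay.dataflow_pattern import is_tuple, wildcard

def _generate_tuple_pattern():
    """Generates a pattern for tuples that can take in 1 to 20 inputs"""
    pattern = is_tuple([wildcard()])
    for i in range(2, 21):  # upper bound is exclusive, so 21 gives arities up to 20
        pattern = pattern | is_tuple([wildcard() for _ in range(i)])
    return pattern
```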

And codify it a little to generate patterns for all three node types. That, unfortunately, incurs some additional complexity on the frontend and some additional runtime spent matching the alternative patterns.
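Codifying the workaround could look like a generator parameterized by a node constructor, so the same code produces the alternation of fixed-arity patterns for tuples, calls, or functions. Below is a toy pure-Python sketch. `generate_variadic`, `match`, `tuple_node`, `concat_call`, and the nested-tuple encoding are hypothetical; in TVM the alternatives would be combined with `|` as in the utility above. It also makes the runtime cost visible: a non-matching expression is tried against every alternative.

```python
def generate_variadic(make_node, max_arity, min_arity=1):
    """Alternation of fixed-arity patterns, one alternative per arity."""
    return ("alt", [make_node(["*"] * n) for n in range(min_arity, max_arity + 1)])


def match(pat, expr, stats):
    """Toy matcher; stats["tried"] counts the alternatives attempted."""
    if pat == "*":
        return True
    if pat[0] == "alt":
        for alt in pat[1]:
            stats["tried"] += 1
            if match(alt, expr, stats):
                return True
        return False
    return (isinstance(expr, tuple) and len(pat) == len(expr) and pat[0] == expr[0]
            and all(match(p, e, stats) for p, e in zip(pat[1:], expr[1:])))


def tuple_node(fields):
    return ("tuple", *fields)


def concat_call(args):
    return ("call:concatenate", *args)  # hypothetical call-node encoding


tuple_pat = generate_variadic(tuple_node, max_arity=20)
stats = {"tried": 0}
print(match(tuple_pat, ("tuple",) + ("x",) * 3, stats), stats)   # True {'tried': 3}
stats = {"tried": 0}
print(match(tuple_pat, ("tuple",) + ("x",) * 21, stats), stats)  # False {'tried': 20}

call_pat = generate_variadic(concat_call, max_arity=4)
print(match(call_pat, ("call:concatenate", "a", "b"), {"tried": 0}))  # True
```

Passing `min_arity=0` would also cover the empty-tuple alternative.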

What do you think? I think I’m leaning towards doing the copy/pasting, it’s a little gross in the backend, but not the end of the world.

Is there a way to clean up the “ugliness”? A reduction in complexity and runtime is more valuable, I think.

I think at the end of the day it’s a case of runtime type inference and specialization. :slight_smile: I’m not sure how to clean that up any more.

OK. I’m fine with the copy/pasting route then.

For what it’s worth (as I mentioned privately), I don’t think having to use the pattern (()|(a,)|(a,b)|(a,b,c)|...) you described is a serious shortcoming, so I favor whatever least complicates the implementation.

So, it sounds like Jeff wants the simplification in the frontend, and Steven wants the simplification in the backend. @masahi, do you have a tie-breaking vote?

I’m ok with copy pasting, and hopefully we can come up with a clean implementation using some C++ tricks.

So…this got lost, but I spent some time today catching up:

Sorry for the delay
