TuplePattern for any number of inputs to the tuple

Hello,

Is it possible to define a pattern for a tuple with any number of inputs? When using the tvm.relay.dataflow_pattern.is_tuple() function, you define a pattern with a fixed number of fields. My main goal is to define a pattern where a tuple of any size is fed into a concat function.

You mean that the number of elements is fixed only when you actually match them, right? You are looking for a “one or more” pattern (the regex equivalent of “+”)? I don’t think we have that yet.

I also have a similar matching problem, where I want to match against an unrolled for loop. I want to match the loop body a certain number of times. But in my case, since I know the unroll factor when I actually instantiate the pattern, it is easy to construct a big pattern consisting of the repeated body pattern on the fly.

cc @mbrookhart

No, we don’t have this right now.

Multi-output patterns and recursive patterns have been on my (very low priority) backlog for a while, mostly because no one has complained about them. It doesn’t sound like you need quite that full functionality, though.

If all you need is “this is a tuple with some number of inputs, and I don’t care how many inputs there are or what they are,” it would be pretty easy to create a Tuple pattern with None fields instead of an Array of fields, and then add some handling to the passes to ignore the pattern fields if they aren’t defined. I can do that today, if that would work for you.

I think it would be pretty straightforward to define that pattern as concat(wildcard()), though: the only thing we can concatenate is a tuple, and if we’re throwing away the information about the fields, this is the same as the None fields case.
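To make the None-fields proposal concrete, here is a toy model of the intended matching semantics in plain Python, so it runs without TVM. The classes `Wildcard` and `TuplePattern` below are hypothetical stand-ins, not TVM's dataflow_pattern classes: a tuple pattern whose `fields` is `None` accepts a tuple of any arity, while an explicit list of fields still requires an exact arity match.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical toy pattern classes; not the tvm.relay.dataflow_pattern API.
class Wildcard:
    def match(self, expr) -> bool:
        # A wildcard accepts any expression.
        return True


@dataclass
class TuplePattern:
    # None means "a tuple with any number of fields, whatever they are".
    fields: Optional[Sequence] = None

    def match(self, expr) -> bool:
        if not isinstance(expr, tuple):
            return False
        if self.fields is None:
            return True
        # With explicit fields, the arity must match exactly and
        # every field pattern must match its element.
        if len(self.fields) != len(expr):
            return False
        return all(p.match(e) for p, e in zip(self.fields, expr))


any_tuple = TuplePattern()
pair = TuplePattern([Wildcard(), Wildcard()])
print(any_tuple.match(("a", "b", "c")))  # True
print(pair.match(("a", "b", "c")))       # False
```

Matching is only half of the story: a partitioning pass still has to decide what becomes an input of the extracted function, which the sketch above does not model.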

That would be great. If it helps, my current workaround for capturing tuples of up to 20 inputs is:

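from tvm.relay.dataflow_pattern import is_tuple, wildcard


def _generate_tuple_pattern():
    """Generates a pattern for tuples that can take in 1 to 20 inputs."""
    pattern = is_tuple([wildcard()])
    # range's upper bound is exclusive, so 21 is needed to include 20-input tuples.
    for i in range(2, 21):
        pattern = pattern | is_tuple([wildcard() for _ in range(i)])
    return pattern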
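This workaround behaves like a bounded alternation: it matches any arity from 1 up to the bound and nothing beyond it, and a matcher may have to try the alternatives one after another. The plain-Python toy below only imitates that behaviour (in TVM, `|` on patterns builds an `AltPattern`); the helper names `make_tuple_pattern`, `bounded_alternatives`, and `match_any` are hypothetical.

```python
def make_tuple_pattern(arity):
    """Hypothetical stand-in for is_tuple([wildcard() for _ in range(arity)])."""
    def match(expr):
        return isinstance(expr, tuple) and len(expr) == arity
    return match


def bounded_alternatives(min_arity, max_arity):
    """Mimic p1 | p2 | ... for arities min_arity..max_arity inclusive."""
    return [make_tuple_pattern(n) for n in range(min_arity, max_arity + 1)]


def match_any(alternatives, expr):
    """Try alternatives left to right; return (matched, attempts)."""
    attempts = 0
    for alt in alternatives:
        attempts += 1
        if alt(expr):
            return True, attempts
    return False, attempts


alts = bounded_alternatives(1, 20)
print(match_any(alts, tuple(range(20))))  # (True, 20)
print(match_any(alts, tuple(range(21))))  # (False, 20)
```

The bound is the key limitation: a 21-input tuple is rejected after trying all 20 alternatives, whereas a true “one or more” pattern would have no upper limit.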

I am using these patterns with relay.transform.MergeComposite. I do not think concat(wildcard()) ropes the input tuple into the pattern. This is the result when executing the code below:

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op

data = []
for i in range(2):
    data.append(relay.var("in_{}".format(i), shape=(1, 4, 4, 3), dtype='float32'))
z = relay.concatenate(data, axis=1)
concat = relay.Function(relay.analysis.free_vars(z), z)
mod = tvm.IRModule.from_expr(concat)


concat_pattern = is_op('concatenate')(wildcard())
mod = relay.transform.MergeComposite([('concat_pattern', concat_pattern)])(mod)
print(mod)

Output:

def @main(%in_0: Tensor[(1, 4, 4, 3), float32], %in_1: Tensor[(1, 4, 4, 3), float32]) -> Tensor[(1, 8, 4, 3), float32] {
  %0 = (%in_0, %in_1);
  %1 = fn (%FunctionVar_0_0: (Tensor[(1, 4, 4, 3), float32], Tensor[(1, 4, 4, 3), float32]), PartitionedFromPattern="concatenate_", Composite="concat_pattern", Primitive=1) -> Tensor[(1, 8, 4, 3), float32] {
    concatenate(%FunctionVar_0_0, axis=1) /* ty=Tensor[(1, 8, 4, 3), float32] */
  };
  %1(%0) /* ty=Tensor[(1, 8, 4, 3), float32] */
}

Ah, I see, the wildcard becomes a variable for partitioning. That makes sense.

Give me a few hours; I’ll try to get this done around meetings this morning.

Sorry for the delay; meetings ate up my morning. I’ve implemented my initial idea, and while it passes the matching test, it doesn’t pass the partitioning test, just like is_op("concatenate")(wildcard()).

This is happening because in partitioning, I’m assuming that patterns without inputs are themselves inputs to the subgraph, so I treat them as inputs to the partitioned function.

Changing that assumption to bring those matches into the partitioned function is going to break a lot of other things, so I don’t think that’s feasible in the short term.
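The partitioning rule described here can be illustrated with a toy walker: a sub-pattern with no inputs of its own (a wildcard, in this sketch) binds to whatever expression it matched, and that expression becomes a parameter of the partitioned function rather than being pulled inside it. The tuple-based encoding and the name `collect_inputs` are hypothetical, not TVM internals.

```python
def collect_inputs(pattern, expr):
    """Return the sub-expressions that become partitioned-function inputs.

    Toy encoding (hypothetical):
      patterns:    ("wildcard",), ("call", op, [args]), ("tuple", [fields])
      expressions: ("var", name), ("call", op, [args]), ("tuple", [fields])
    The pattern is assumed to have already matched expr.
    """
    kind = pattern[0]
    if kind == "wildcard":
        # A pattern with no inputs: whatever it matched becomes an input.
        return [expr]
    if kind == "call":
        children = zip(pattern[2], expr[2])
    elif kind == "tuple":
        children = zip(pattern[1], expr[1])
    else:
        raise ValueError(f"unknown pattern kind: {kind}")
    inputs = []
    for sub_pattern, sub_expr in children:
        inputs.extend(collect_inputs(sub_pattern, sub_expr))
    return inputs


x0, x1 = ("var", "in_0"), ("var", "in_1")
graph = ("call", "concatenate", [("tuple", [x0, x1])])

# concat(wildcard()): the whole tuple is an input, so it stays outside.
print(collect_inputs(("call", "concatenate", [("wildcard",)]), graph))
# concat(is_tuple([wildcard(), wildcard()])): the vars are the inputs,
# so the tuple is built inside the partitioned function.
print(collect_inputs(
    ("call", "concatenate", [("tuple", [("wildcard",), ("wildcard",)])]), graph))
```

The first case corresponds to the MergeComposite output above, where `%0 = (%in_0, %in_1)` is created outside the composite function and passed in as `%FunctionVar_0_0`.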

Your workaround might be the best thing in the short term. I’d need to think about how to support a regex-style “+” in a tuple; I’m not sure what the knock-on effects would be.

That’s all right. That input assumption is good as is. Maybe the solution would look like is_op("concatenate")(is_tuple(wildcard())), where a wildcard() input to the tuple captures any type of input into the tuple, or maybe is_tuple(wildcard()) looks for any tuple operation.

Hmm, I will keep stewing. That’s an interesting idea, but I think it’s going to run into type issues on the C++ side.

This works, but I donā€™t love it. Thoughts?

It looks fine. Maybe instead of None we just leave the field blank?

is_tuple(None) → is_tuple().

To me, None seems to imply an empty tuple. It would also be consistent with is_constant() and wildcard().

So, back to thinking about this this morning. @jeff815 @slyubomirsky

We basically have three places we might want to match any number of inputs: Call, Function, and Tuple.

@masahi recommended a Many Pattern, but since we aren’t trying to match the IR itself but rather a variable-size container in the IR, the lock-step recursive matching of the pattern matcher breaks when this is introduced as a Pattern. We went on to discover that what he actually needs is recursive matching, which is a bigger problem.

I can do what I did in that PR for Call and Function; it just requires copy/pasting this section of the code for everything we’d want to support:

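      // A TuplePattern whose fields are undefined matches a tuple of any arity.
      auto tuple = node->ref_.as<TuplePatternNode>();
      if (tuple && !tuple->fields.defined()) {
        if (node_map.count(node->ref_)) {
          auto matches = node_map[node->ref_];
          for (auto match : matches) {
            // Each field of each matched tuple becomes an input to the
            // partitioned function, so the tuple itself is built inside it.
            for (auto input : match.as<TupleNode>()->fields) {
              make_input(input);
            }
          }
        }
      }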

which is an ugly design, but doable.
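In plain Python, the special case in that snippet amounts to roughly the following: when a tuple pattern has no fields defined, every tuple it matched contributes each of its fields as a separate input. Here `node_map` is a plain dict from a pattern to the list of expressions it matched, and `TuplePat` / `TupleExpr` are hypothetical stand-ins for `TuplePatternNode` / `TupleNode`, not TVM's C++ types.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for TuplePatternNode / TupleNode.
@dataclass(frozen=True)
class TuplePat:
    fields: Optional[tuple] = None  # None plays the role of !fields.defined()


@dataclass(frozen=True)
class TupleExpr:
    fields: tuple


def inputs_for_variadic_tuple(pattern, node_map):
    """Collect partitioned-function inputs for a fields-undefined tuple pattern."""
    inputs = []
    if isinstance(pattern, TuplePat) and pattern.fields is None:
        # Mirrors: for each match of this pattern, make_input(field) per field.
        for match in node_map.get(pattern, []):
            inputs.extend(match.fields)
    return inputs


any_tuple = TuplePat()
node_map = {any_tuple: [TupleExpr(("in_0", "in_1", "in_2"))]}
print(inputs_for_variadic_tuple(any_tuple, node_map))  # ['in_0', 'in_1', 'in_2']
```

The copy/paste concern in the thread follows from this shape: a Call or Function pattern with undefined inputs would need the same branch, differing only in which container's elements are turned into inputs.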

The other option, I think, is to take this kind of utility, which you both seem to be using:

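from tvm.relay.dataflow_pattern import is_tuple, wildcard


def _generate_tuple_pattern():
    """Generates a pattern for tuples that can take in 1 to 20 inputs."""
    pattern = is_tuple([wildcard()])
    # range's upper bound is exclusive, so 21 is needed to include 20-input tuples.
    for i in range(2, 21):
        pattern = pattern | is_tuple([wildcard() for _ in range(i)])
    return pattern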

And codify it a little to generate patterns for all three node types. That, unfortunately, does incur some additional complexity on the frontend and some additional runtime in matching the alternative patterns.

What do you think? I think I’m leaning towards doing the copy/pasting; it’s a little gross in the backend, but not the end of the world.

Is there a way to clean up the “ugliness”? A reduction in complexity and runtime is more valuable, I think.

I think at the end of the day it’s a case of runtime type inference and specialization. :slight_smile: I’m not sure how to clean that up any more.

OK. I’m fine with the copy/pasting route, then.

For what it’s worth (as I mentioned privately), I don’t think having to use the pattern (()|(a,)|(a,b)|(a,b,c)|...) you described is a serious shortcoming, so I favor whatever least complicates the implementation.
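The regex analogy used in this thread can be checked directly with Python's `re` module by encoding a tuple as one character per field (an encoding chosen only for illustration): a bounded alternation such as (a,)|(a,b)|(a,b,c) accepts arities up to its bound, while “+” accepts any positive arity.

```python
import re


def encode(arity):
    # Illustrative encoding only: one "a" per tuple field.
    return "a" * arity


# (a,)|(a,b)|(a,b,c): alternation bounded at three fields.
bounded = re.compile(r"(?:a|aa|aaa)")
# "One or more" fields: what a variadic tuple pattern would express.
one_or_more = re.compile(r"a+")

for n in (1, 3, 4):
    s = encode(n)
    print(n, bool(bounded.fullmatch(s)), bool(one_or_more.fullmatch(s)))
```

The bounded form is perfectly usable when an upper limit on arity is known in advance, which is the trade-off being weighed here.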

So, it sounds like Jeff wants the simplification in the frontend, and Steven wants the simplification in the backend. @masahi, do you have a tie-breaking vote?

I’m OK with copy/pasting, and hopefully we can come up with a clean implementation using some C++ tricks.

So… this got lost, but I spent some time today catching up:

Sorry for the delay