Pattern matching: does `has_attr` support matching `kernel_size` for nn.conv2d?

As the title says. I’m trying to match a conv2d op with `pattern = is_op('nn.conv2d')(data, weight).has_attr({"kernel_size": [3, 3]})`, but with no luck. I suspect the list `[3, 3]` is not being passed correctly. When a conv2d is matched, the `kernel_size` field in its attrs seems to have type `tvm.ir.Array`, but I can’t create an Array from a list: it says `Array()` takes no arguments.

Can anyone help?

I think this is a limitation of the pattern matcher. We don’t match attributes of types other than int/float/string.

cc @mbrookhart

Hmm, this is interesting :slight_smile: When we started the attr pattern, we expected it to be used mostly for op attributes; call attributes have grown since then.

Could you post a minimal reproducible example? I think we probably need to add `kTVMObjectHandle` to this switch statement:

I’m not sure where you’re attempting to do this.

Is that something you’re trying to add to the pattern matcher?

I have one in case you need it:

```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

x = relay.var("x")
y = relay.var("y")

# Pattern: any conv2d whose kernel_size attribute is [3, 3]
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3])
print(is_conv2d.match(out))  # 0 (no match)
```

I also found the code you pointed out and realized that we probably don’t support that.

I would expect this to error out :confused:

I guess the reason it doesn’t error out is that the value is a `kTVMObjectHandle` rather than a `String` object, so the function just returns false.
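A simplified pure-Python model may help show why the match fails silently instead of erroring. This is a hypothetical sketch, not TVM’s actual C++ matcher. `attr_equal_limited` only handles the int/float/string types mentioned above and falls through to `False` for anything else, such as a Python list standing in for `tvm.ir.Array`. `attr_equal_extended` shows one possible way to also compare array-like values element by element. That second function is an assumption and not necessarily what the eventual fix does.

```python
def attr_equal_limited(expected, actual):
    """Toy model (hypothetical): compare only int/float/str attribute values.

    Any other type, e.g. a list standing in for tvm.ir.Array, has no branch
    and is reported as a mismatch instead of raising an error.
    """
    if isinstance(expected, (int, float, str)):
        return expected == actual
    # No case for array-like values: silently report "no match".
    return False


def attr_equal_extended(expected, actual):
    """Toy model of one possible fix: also compare array-like values element-wise."""
    if isinstance(expected, (list, tuple)):
        if not isinstance(actual, (list, tuple)) or len(expected) != len(actual):
            return False
        return all(attr_equal_extended(e, a) for e, a in zip(expected, actual))
    return attr_equal_limited(expected, actual)


conv_attrs = {"kernel_size": [3, 3], "channels": 64}
print(attr_equal_limited([3, 3], conv_attrs["kernel_size"]))   # False
print(attr_equal_extended([3, 3], conv_attrs["kernel_size"]))  # True
```

Returning `False` for unknown types keeps the matcher from crashing, but it also hides unsupported attribute types from the user. That is why the original pattern just printed a non-match.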

Yep, I see. I’ll try to fix it quickly.


@mbrookhart @comaniac Thank you both! You guys are amazing.

(I feel a bit greedy here :) ) A follow-up question: would it be possible to also support matching a boolean condition, e.g. something like `has_attr({"channels": value > 32})`?

It would be too complex for the current design IMHO, as it would need to evaluate the condition expression. The solution I would suggest for this case is to first match conv2d with arbitrary channels, and then only rewrite the conv2d when channels > 32 in your rewrite callback.

I see, thanks for the answer. Actually, I realize this can be solved by applying a check function when the graph is being partitioned, so there is nothing to worry about during rewriting.
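Both suggestions above lead to the same split: filtering in the rewrite callback, or checking at partition time. The pattern stays purely structural, and arbitrary conditions such as `channels > 32` are evaluated in plain code afterwards. The toy model below illustrates that split in pure Python. `Call`, `matches`, `rewrite_if` and the op name `my.wide_conv2d` are hypothetical names for illustration only. In TVM itself, the corresponding hooks would be the `callback(pre, post, node_map)` method of a `DFPatternCallback` subclass, or the `check` argument of `DFPattern.partition`. Please verify these against the API of your TVM version.

```python
from dataclasses import dataclass, field


@dataclass
class Call:
    """Hypothetical stand-in for a Relay call node: op name plus attrs."""
    op: str
    attrs: dict = field(default_factory=dict)


def matches(call, op_name, exact_attrs):
    """Structural match: op name plus exact attribute equality (like has_attr)."""
    return call.op == op_name and all(
        call.attrs.get(k) == v for k, v in exact_attrs.items()
    )


def rewrite_if(calls, op_name, exact_attrs, check, rewrite_fn):
    """Match structurally, then apply an arbitrary predicate before rewriting.

    `check` plays the role of the condition the pattern cannot express
    (e.g. channels > 32); unmatched or rejected calls pass through unchanged.
    """
    out = []
    for call in calls:
        if matches(call, op_name, exact_attrs) and check(call):
            out.append(rewrite_fn(call))
        else:
            out.append(call)
    return out


graph = [
    Call("nn.conv2d", {"kernel_size": [3, 3], "channels": 16}),
    Call("nn.conv2d", {"kernel_size": [3, 3], "channels": 64}),
    Call("nn.relu"),
]
result = rewrite_if(
    graph,
    "nn.conv2d",
    {"kernel_size": [3, 3]},
    check=lambda c: c.attrs.get("channels", 0) > 32,
    # "my.wide_conv2d" is a hypothetical replacement op name.
    rewrite_fn=lambda c: Call("my.wide_conv2d", dict(c.attrs)),
)
print([c.op for c in result])  # ['nn.conv2d', 'my.wide_conv2d', 'nn.relu']
```

With this split, the pattern language only needs exact-equality attribute matching. Any numeric or more complex condition lives in ordinary code.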