The current TypePattern requires the type in the pattern to match the type in the expression exactly. Is there a way to match only according to the element type, for example to match an fp16 add?
has_type is not limited to the final expression; it can also be used as follows to check the type of an op:
from tvm import relay
from tvm.relay.dataflow_pattern import *

in1 = wildcard()
in2 = wildcard()
pat = is_op('add')(in1, in2).has_type(relay.TensorType((10, 10), 'float32'))
x = relay.var('x', shape=(10, 10), dtype='float32')
pat.match(relay.add(x, x))
This can be extended to a more complex pattern like
in1 = wildcard()
in2 = wildcard()
add = is_op('add')(in1, in2).has_type(relay.TensorType((10, 10), 'float32'))
mul = is_op('multiply')(add, add)
In this case, you only match the type of add but do not care about the type of mul, although in this case mul must be float32, too.
Can it match a shape other than (10, 10)?
Of course. Shape is a part of the type.
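For example, the same kind of pattern with a different shape in the TensorType only matches tensors of that shape (a minimal sketch mirroring the example above; the (5, 5) shape and float16 dtype are arbitrary choices for illustration):
from tvm import relay
from tvm.relay.dataflow_pattern import *

# Both the shape and the dtype inside TensorType take part in the match.
pat = is_op('add')(wildcard(), wildcard()).has_type(relay.TensorType((5, 5), 'float16'))
y = relay.var('y', shape=(5, 5), dtype='float16')
print(pat.match(relay.add(y, y)))   # matches: shape and dtype both agree
z = relay.var('z', shape=(10, 10), dtype='float16')
print(pat.match(relay.add(z, z)))   # does not match: the shape differs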
I'm afraid I haven't made my point clear. By using has_type, we can match a specific dtype, but the shape of the tensor is also fixed. By using wildcard(), we can match (AnyShape, AnyType). How can we achieve something like has_type((AnyShape, 'float32'))?
Thanks
I see your point and it seems fair.
In the current implementation, one solution I can think of is leveraging the check function:
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def check(pre):
    # 'pre' is the matched 'add' call; require both of its inputs to be float32.
    return (pre.args[0].checked_type.dtype == 'float32' and
            pre.args[1].checked_type.dtype == 'float32')
pat = is_op('add')(wildcard(), wildcard())
x = relay.var('x', shape=(10, 10), dtype='float32')
out = relay.add(x, x)
func = relay.Function([x], out)
mod = tvm.IRModule()
mod['main'] = func
mod = relay.transform.InferType()(mod)
print(pat.partition(mod['main'].body, check=check))
In short, you can implement a check function that performs any form of checking by looking into the matched subgraph.
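For instance, if you only care about the element type and not the shape, the check can look solely at the dtype of the matched call (a minimal sketch reusing pat and mod from above; the name check_fp32 is just for illustration):
# Only constrain the output dtype of the matched 'add'; any shape is accepted.
def check_fp32(pre):
    return pre.checked_type.dtype == 'float32'

print(pat.partition(mod['main'].body, check=check_fp32))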
In case you only need to know whether the pattern matches and do not want to partition the graph, you can use rewrite to mimic the above functionality:
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
class MyCallback(DFPatternCallback):
    def __init__(self):
        super().__init__()
        self.in1 = wildcard()
        self.in2 = wildcard()
        self.pattern = is_op('add')(self.in1, self.in2)
        self.match = False

    def callback(self, pre, post, node_map):
        # Record a match only when both inputs are float32; return the
        # expression unchanged so the graph is not actually mutated.
        if (node_map[self.in1][0].checked_type.dtype == 'float32' and
                node_map[self.in2][0].checked_type.dtype == 'float32'):
            self.match = True
        return post
x = relay.var('x', shape=(10, 10), dtype='float32')
out = relay.add(x, x)
func = relay.Function([x], out)
mod = tvm.IRModule()
mod['main'] = func
mod = relay.transform.InferType()(mod)
callback = MyCallback()
rewrite(callback, mod['main'].body)
print(callback.match)
When matching the pattern, rewrite will call the callback function for graph mutation. You can of course add any checks there and maintain your own 'match' status.
While the above solutions work, they are definitely imperfect, especially when the pattern is complex. In the long term, we may want to support partial type matching in the pattern language.
cc @mbrookhart
Thanks a lot for this detailed sample!
My apologies! I somehow missed this last week.
Yeah, the current TypePattern matches the full type via StructuralEqual. One possibility to clean this up slightly is to add a rule:
- If it's a TensorType
- and the pattern's shape is ()
- only check the dtype
That would only take a handful of lines in the C++ matcher to implement. It's a little specialized, but easier than doing it with callbacks.
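Under that rule, usage would hypothetically look like the following (a sketch of the proposed behavior only; the current matcher does not treat an empty shape this way):
# Hypothetical: an empty shape in the pattern's TensorType would mean
# "any shape", so only the float16 dtype would be checked.
pat = is_op('add')(wildcard(), wildcard()).has_type(relay.TensorType((), 'float16'))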
Makes sense. The ideal interface for this case would be to leverage has_type just like other type matching. We may need new patterns like AnyShape, or support for Wildcard in tensor shapes (which seems much harder to me).
Yeah, it's a bit complicated. The current pattern uses the tvm::ir::Type* classes, so you can match a number of different kinds of types, but as this question reveals, we may want finer granularity on some types.
Unfortunately, since we're using a lower-level Type object, we won't be able to embed Pattern-specific things into that Type (like wildcard or AnyShape).
I think our options are to add rules like the one above, or to extend the pattern language to include things like "ShapePattern" and "DTypePattern" at a more granular level.
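For reference, such patterns might hypothetically be exposed like this (sketch only; has_dtype and has_shape are illustrative names, not existing API at the time of this discussion):
# Hypothetical finer-grained patterns: constrain only the dtype, or only the shape.
pat_fp16 = is_op('add')(wildcard(), wildcard()).has_dtype('float16')
pat_10x10 = is_op('add')(wildcard(), wildcard()).has_shape((10, 10))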
What do you guys think? Which would be easier to use?
I personally like ShapePattern and DTypePattern more, as they are more straightforward in the pattern language, but I'd also like to see others' opinions.
I would also agree with adding these two patterns instead of special-casing the matching APIs/rules.
I'll throw together a PR with those extensions. Thanks!