Hello, I’m trying to partition a Relay graph into functions and then rewrite those functions, but the rewrite has no effect. Here’s a minimal working example:
```python
import tvm
import tvm.relay as relay
from tvm.relay.dataflow_pattern import wildcard, is_op, rewrite, DFPatternCallback, FunctionPattern


class TestCallback(DFPatternCallback):
    def __init__(self):
        super(TestCallback, self).__init__()
        self.x = wildcard()
        self.y = wildcard()
        # Match a function whose body is a single add
        pattern = is_op('add')(self.x, self.y)
        pattern = FunctionPattern([wildcard(), wildcard()], pattern)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        print('here')
        x = node_map[self.x][0]
        y = node_map[self.y][0]
        return x - y


x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
expr = (x + y) * z

# Partition the add into its own function
p = wildcard() + wildcard()
fp = FunctionPattern([wildcard(), wildcard()], p)
print(expr)
expr_p = p.partition(expr)
print(expr_p)

# Try to rewrite the partitioned function
expr_r = rewrite(TestCallback(), expr_p)
print(expr_r)
```
The third print statement prints the same output as the second one, because `TestCallback` fails to match the `FunctionPattern` that contains a single `add` op. Can anyone help?
Thanks in advance!