Rewriting a function in a Relay graph failed

Hello, I’m trying to partition a Relay graph into functions and then rewrite those functions, but the rewrite fails. Here’s a minimal working example:

import tvm
import tvm.relay as relay
from tvm.relay.dataflow_pattern import wildcard, is_op, rewrite, DFPatternCallback, FunctionPattern

class TestCallback(DFPatternCallback):
    def __init__(self):
        super(TestCallback, self).__init__()
        self.x = wildcard()
        self.y = wildcard()

        pattern = is_op('add')(self.x, self.y)
        pattern = FunctionPattern([wildcard(), wildcard()], pattern)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        print('here')
        x = node_map[self.x][0]
        y = node_map[self.y][0]
        return x - y

x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
expr = (x + y) * z

p = wildcard() + wildcard()
fp = FunctionPattern([wildcard(), wildcard()], p)
print(expr)
expr_p = p.partition(expr)
print(expr_p)
expr_r = rewrite(TestCallback(), expr_p)
print(expr_r)

The third print statement prints the same output as the second one, because TestCallback fails to match the FunctionPattern that contains the single add op. Can anyone help?

Thanks in advance!

Could anyone kindly help me out?

Change

pattern = FunctionPattern([wildcard(), wildcard()], pattern)

To

pattern = FunctionPattern([wildcard(), wildcard()], pattern)(wildcard(), wildcard())

This should solve your issue: your pattern also needs to “call” the function, not just define it.

That solves my problem. Thank you so much!