Rewriting a function in a Relay graph failed

Hello, I’m trying to partition a Relay graph into some functions and then rewrite those functions, but the rewrite fails. Here’s a minimal example:

import tvm
import tvm.relay as relay
from tvm.relay.dataflow_pattern import wildcard, is_op, rewrite, DFPatternCallback, FunctionPattern

class TestCallback(DFPatternCallback):
    def __init__(self):
        super(TestCallback, self).__init__()
        self.x = wildcard()
        self.y = wildcard()

        # Match a two-parameter function whose body is a single add.
        pattern = is_op('add')(self.x, self.y)
        pattern = FunctionPattern([wildcard(), wildcard()], pattern)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        y = node_map[self.y][0]
        return x - y

x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
expr = (x + y) * z

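# Wrap every add into its own function, then try to rewrite that function.
p = wildcard() + wildcard()
fp = FunctionPattern([wildcard(), wildcard()], p)
expr_p = p.partition(expr)
expr_r = rewrite(TestCallback(), expr_p)

print(expr)    # original graph
print(expr_p)  # partitioned graph: the add now lives inside a function
print(expr_r)  # rewritten graph: identical to expr_p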

Printing expr_r gives the same output as printing expr_p, because TestCallback fails to match the FunctionPattern that contains the single add op. Can anyone help?

Thanks in advance!

Could anyone kindly help me out?


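Changing

pattern = FunctionPattern([wildcard(), wildcard()], pattern)

to

pattern = FunctionPattern([wildcard(), wildcard()], pattern)(wildcard(), wildcard())

could solve your issue. Your pattern should also “call” the function instead of just defining it: after partition(), the add appears in the graph as a call to the newly created function, so the pattern has to describe that call.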

That solves my problem. Thank you so much!