Recursive visiting in `PyExprMutator`

I was trying to write my own visitor that does some work itself and delegates the rest to the superclass’s visitor, roughly like this:

```python
from tvm import relax


@relax.expr_functor.mutator
class MyMutator(relax.PyExprMutator):
    ...

    def visit_dataflow_block_(self, block):
        # visit bindings...
        # Do the equivalent of C++'s PyExprMutator::visit_dataflow_block(block):
        super(relax.PyExprMutator, self).visit_dataflow_block(block)  # Did not work

    def visit_call_(self, call):
        # Not called if I have my own block visitor.
        ...
```

This didn’t work because the superclass didn’t seem to have any visitor methods (I guess the decorator did something here). Without calling the superclass’s visitor, I don’t know how to make it visit the elements of the block.

Does PyExprMutator even support this?
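
For reference, here is the delegation pattern being asked about, in plain Python with no TVM involved. `Block`, `ToyMutator`, `MyMutator`, `visit_block_`, and `visit_binding_` are hypothetical stand-ins that only mimic the shape of `PyExprMutator`; the sketch says nothing about how the `@relax.expr_functor.mutator` decorator changes dispatch. One Python detail it does show: `super(C, self)` starts the method lookup *after* `C` in the MRO, so `super(relax.PyExprMutator, self)` skips `PyExprMutator`'s own methods, whereas plain `super()` inside the subclass reaches them.

```python
class Block:
    """Hypothetical stand-in for a dataflow block: just a list of bindings."""

    def __init__(self, bindings):
        self.bindings = list(bindings)


class ToyMutator:
    """Hypothetical base class with a default recursive traversal."""

    def visit_block_(self, block):
        # Default: rebuild the block by visiting every binding in it.
        return Block(self.visit_binding_(b) for b in block.bindings)

    def visit_binding_(self, binding):
        # Default: leave the binding unchanged.
        return binding


class MyMutator(ToyMutator):
    def __init__(self):
        self.blocks_seen = 0

    def visit_block_(self, block):
        self.blocks_seen += 1                # our own work first...
        return super().visit_block_(block)   # ...then the inherited traversal

    def visit_binding_(self, binding):
        # Still called for each element, via the inherited visit_block_.
        return binding.upper()


m = MyMutator()
out = m.visit_block_(Block(["a = x * 2", "b = a + 1"]))
print(out.bindings)      # ['A = X * 2', 'B = A + 1']
print(m.blocks_seen)     # 1

# super(ToyMutator, m) begins the lookup *after* ToyMutator in the MRO,
# so it only reaches `object`, which has no visitor methods.
print(hasattr(super(ToyMutator, m), "visit_block_"))  # False
```

Because the inherited `visit_block_` calls `self.visit_binding_`, the override in `MyMutator` still runs for every element of the block.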

I think this is a bug in `PyExprMutator` (and in the visitor as well). We previously observed endless recursion with `visit_var_binding_` and fixed it in https://github.com/apache/tvm/pull/14754. It looks to me like the `visit_dataflow_block` issue can be fixed in a similar way.
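
As a rough illustration of how this kind of endless recursion can arise, here is a plain-Python sketch under an assumed mechanism: the runtime calls the Python method only when a subclass overrides it, and the Python-side "default" routes back through that same dispatch instead of running the built-in logic. `BuggyBase`, `FixedBase`, `dispatch`, `visit_node_`, and `_default` are hypothetical names; this is not TVM's code and not a description of what PR 14754 actually changes.

```python
class BuggyBase:
    def dispatch(self, node):
        # Entry point used by the traversal (think: the C++ virtual method).
        # It calls the Python method only when a subclass overrides it.
        if type(self).visit_node_ is not BuggyBase.visit_node_:
            return self.visit_node_(node)
        return ("visited", node)

    def visit_node_(self, node):
        # BUG: the "default" goes back through dispatch(), which sees the
        # override again: override -> super() -> dispatch() -> override ...
        return self.dispatch(node)


class FixedBase:
    def dispatch(self, node):
        if type(self).visit_node_ is not FixedBase.visit_node_:
            return self.visit_node_(node)
        return self._default(node)

    def visit_node_(self, node):
        # Fix: run the built-in logic directly, without re-dispatching.
        return self._default(node)

    def _default(self, node):
        return ("visited", node)


class MyBuggy(BuggyBase):
    def visit_node_(self, node):
        return super().visit_node_(node)  # "call the default"


class MyFixed(FixedBase):
    def __init__(self):
        self.seen = []

    def visit_node_(self, node):
        self.seen.append(node)            # our own work...
        return super().visit_node_(node)  # ...then the default


fixed = MyFixed()
print(fixed.dispatch("x"))  # ('visited', 'x')

try:
    MyBuggy().dispatch("x")
except RecursionError:
    print("MyBuggy: endless recursion")
```

The only difference between the two bases is whether the "default" hook re-enters `dispatch()` or calls the non-overridable logic directly.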

CC @leshenj15, the original author of PyExprMutator :slight_smile:

Thanks @kparzysz for sharing your experience! PR to fix this: https://github.com/apache/tvm/pull/15189

Thank you @leshenj15 for the quick turnaround!

I got a similar issue with `rewrite_call`:

```python
from tvm.relax.dpl import is_op, rewrite_call, wildcard
from tvm.script import relax as R


def test_rewrite_simple():
    @R.function
    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            y = R.multiply(x, R.const(2, "float32"))
            R.output(y)
        return y

    # Match any `relax.multiply` call and capture both operands.
    x = wildcard()
    y = wildcard()
    pattern = is_op("relax.multiply")(x, y)

    def rewriter_passed_case_1(call, matchings):
        # multiply(x, 2): same shape as the matched call.
        print(call)
        t = matchings[x]
        t = R.multiply(t, R.const(2, "float32"))
        return t

    def rewriter_passed_case_2(call, matchings):
        # add(layout_transform(x), 2): no multiply left in the result.
        print(call)
        t = R.layout_transform(matchings[x], index_map=lambda i, j: (i, j))
        t = R.add(t, R.const(2, "float32"))
        return t

    def rewriter_issue_case(call, matchings):
        # multiply(layout_transform(x), 2): the result is again a multiply.
        print(call)
        t = R.layout_transform(matchings[x], index_map=lambda i, j: (i, j))
        t = R.multiply(t, R.const(2, "float32"))
        return t

    rewritten = rewrite_call(pattern, rewriter_passed_case_1, main)
    print("passed case 1: ", rewritten)
    rewritten = rewrite_call(pattern, rewriter_passed_case_2, main)
    print("passed case 2: ", rewritten)
    rewritten = rewrite_call(pattern, rewriter_issue_case, main)
    print("passed issue case: ", rewritten)


test_rewrite_simple()
```

However, the same logic works when I implement it with `PyExprMutator` instead. Does `rewrite_call` have the same problem?
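
One hypothesis that fits all three cases, stated as an assumption since I have not checked the `rewrite_call` implementation: if the rewrite is re-applied until the function stops changing, a rewriter whose output still matches `is_op("relax.multiply")(x, y)` never reaches a fixpoint, while a hand-written single-pass mutator rewrites each original node once and never revisits its own output. The toy below models expressions as tuples; `rewrite_once`, `rewrite_to_fixpoint`, `case_1`, `case_2`, and `issue_case` are hypothetical stand-ins, not TVM APIs.

```python
def is_mul(e):
    return isinstance(e, tuple) and e[0] == "mul"


# Toy versions of the three rewriters; e is ("mul", lhs, rhs).
def case_1(e):      # multiply(x, 2): equal to the input
    return ("mul", e[1], "2")


def case_2(e):      # add(layout_transform(x), 2): no multiply left
    return ("add", ("layout_transform", e[1]), "2")


def issue_case(e):  # multiply(layout_transform(x), 2): matches again
    return ("mul", ("layout_transform", e[1]), "2")


def rewrite_once(e, rule):
    """One post-order pass: each node of the input is rewritten at most once
    and the rule's output is not revisited (like a hand-written mutator)."""
    if isinstance(e, tuple):
        e = (e[0],) + tuple(rewrite_once(c, rule) for c in e[1:])
    return rule(e) if is_mul(e) else e


def rewrite_to_fixpoint(e, rule, max_iters=10):
    """Repeat the pass until the expression stops changing (assumed behaviour)."""
    for _ in range(max_iters):
        new = rewrite_once(e, rule)
        if new == e:
            return new
        e = new
    raise RuntimeError(f"no fixpoint after {max_iters} passes")


main = ("mul", "x", "2")
for rule in (case_1, case_2, issue_case):
    once = rewrite_once(main, rule)
    try:
        result = rewrite_to_fixpoint(main, rule)
    except RuntimeError as err:
        result = err
    print(rule.__name__, "| single pass:", once, "| fixpoint:", result)
```

Case 1 returns an expression equal to its input and case 2 removes the multiply, so both settle; the issue case wraps `x` in one more `layout_transform` on every pass, so only the single-pass version terminates.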