I was trying to do something like this: write my own visitor that does some work itself and uses the superclass’s visitor for the rest:
from tvm import relax

@relax.expr_functor.mutator
class MyMutator(relax.PyExprMutator):
    ...

    def visit_dataflow_block_(self, block):
        # visit bindings...
        # Do the equivalent of C++'s PyExprMutator::visit_dataflow_block(block):
        super(relax.PyExprMutator, self).visit_dataflow_block(block)  # Did not work

    def visit_call_(self, call):
        # Not called if I have my own block visitor.
        ...
This didn’t work: the superclass didn’t seem to have any visitor methods (I guess the decorator did something here). Without calling the superclass’s visitor, I don’t know how to make it invoke the visitors for the elements of the block.
Does PyExprMutator even support this?
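A side note on the `super(relax.PyExprMutator, self)` line: in Python, `super(C, self)` starts the attribute lookup at the class *after* `C` in the method resolution order, so it skips `PyExprMutator`’s own methods entirely. The plain-Python sketch below uses hypothetical stand-in classes `Base` and `Mine` (no TVM involved) to show the difference. The usual spelling inside `MyMutator` would be `super().visit_dataflow_block_(block)`, although whether that avoids the recursion bug discussed in the replies depends on the TVM version.

```python
class Base:
    # Stand-in for relax.PyExprMutator (hypothetical, plain Python).
    def visit_block(self, block):
        return f"Base.visit_block({block})"


class Mine(Base):
    # Stand-in for MyMutator.
    def visit_block(self, block):
        # Zero-argument super() starts the lookup *after* Mine,
        # so it finds Base.visit_block.
        return super().visit_block(block)

    def skip_base(self, block):
        # super(Base, self) starts the lookup *after* Base, i.e. at `object`,
        # which has no visit_block, so this raises AttributeError.
        return super(Base, self).visit_block(block)


m = Mine()
print(m.visit_block("b0"))  # Base.visit_block(b0)
try:
    m.skip_base("b0")
except AttributeError as err:
    print("AttributeError:", err)
```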
I guess this is a bug in PyExprMutator (and in the visitor as well). We once observed endless recursion with visit_var_binding_ and fixed it in https://github.com/apache/tvm/pull/14754. It looks to me like the visit_dataflow_block issue can be fixed in a similar way.
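To make that failure mode concrete, here is a plain-Python toy model. All class and method names are hypothetical, and this is not how TVM’s FFI-based PyExprMutator is actually implemented. A subclass overrides a block hook, does its own work, and delegates to the base class. With a correct default, the overridden `visit_call` still runs for every binding; with a default that re-enters the public dispatcher, the call bounces back into the override forever, which is the kind of endless recursion described above.

```python
class ToyMutator:
    """Hypothetical base mutator (a toy, not TVM's implementation)."""

    def dispatch_block(self, block):
        # Public entry point: always routes to the most-derived visit_block.
        return self.visit_block(block)

    def visit_block(self, block):
        # Correct default: do the base work directly, visiting each binding
        # through the overridable hook so a subclass's visit_call still runs.
        return [self.visit_call(call) for call in block]

    def visit_call(self, call):
        return call  # default: leave the call unchanged


class BuggyToyMutator(ToyMutator):
    def visit_block(self, block):
        # Buggy default: re-enters the dispatcher instead of doing the work,
        # so an override that calls super() is dispatched to again and again.
        return self.dispatch_block(block)


def make_user_mutator(base):
    """Build a user mutator on top of `base` (hypothetical helper)."""

    class UserMutator(base):
        def __init__(self):
            self.blocks_seen = 0

        def visit_block(self, block):
            self.blocks_seen += 1  # our own per-block work ...
            return super().visit_block(block)  # ... then delegate the rest

        def visit_call(self, call):
            return call.upper()  # toy "rewrite" of every call

    return UserMutator()


good = make_user_mutator(ToyMutator)
print(good.dispatch_block(["add", "multiply"]))  # ['ADD', 'MULTIPLY']

bad = make_user_mutator(BuggyToyMutator)
try:
    bad.dispatch_block(["add"])
except RecursionError:
    print("endless recursion")
```

The toy’s fix is simply that the base default performs the block work itself instead of routing back through the overridable entry point.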
CC @leshenj15, the original author of PyExprMutator
Thanks @kparzysz for sharing your experience! PR to fix this: https://github.com/apache/tvm/pull/15189
Thank you @leshenj15 for the quick turnaround!
I got a similar issue with rewrite_call:
from tvm.relax.dpl import *
from tvm.script import relax as R


def test_rewrite_simple():
    @R.function
    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            y = R.multiply(x, R.const(2, "float32"))
            R.output(y)
        return y

    x = wildcard()
    y = wildcard()
    pattern = is_op("relax.multiply")(x, y)

    def rewriter_passed_case_1(_, matchings):
        print(_)
        t = matchings[x]
        t = R.multiply(t, R.const(2, "float32"))
        return t

    def rewriter_passed_case_2(_, matchings):
        print(_)
        t = matchings[x]
        t = R.layout_transform(matchings[x], index_map=lambda i, j: (i, j))
        t = R.add(t, R.const(2, "float32"))
        return t

    def rewriter_issue_case(_, matchings):
        print(_)
        t = matchings[x]
        t = R.layout_transform(matchings[x], index_map=lambda i, j: (i, j))
        t = R.multiply(t, R.const(2, "float32"))
        return t

    rewritten = rewrite_call(pattern, rewriter_passed_case_1, main)
    print("passed case 1: ", rewritten)
    rewritten = rewrite_call(pattern, rewriter_passed_case_2, main)
    print("passed case 2: ", rewritten)
    rewritten = rewrite_call(pattern, rewriter_issue_case, main)
    print("passed issue case: ", rewritten)


test_rewrite_simple()
However, the same logic works when I implement it with PyExprMutator. Does rewrite_call have the same problem?
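One hypothesis worth checking, not confirmed anywhere in this thread: if rewrite_call keeps re-applying the rewriter until the function stops changing, then the issue case can never settle, because its output is again a `multiply` that matches the pattern, just wrapped in one more `layout_transform`. Case 1 would converge because its output is structurally identical to its input, and case 2 because `add` no longer matches. The toy below models this in plain Python. Expressions are tuples, and `match_multiply` and `rewrite_until_fixed_point` are hypothetical helpers, not TVM APIs. It only rewrites at the root, which is enough to show the non-termination.

```python
def match_multiply(expr):
    """Toy matcher for is_op("relax.multiply")(x, y); returns bindings or None."""
    if isinstance(expr, tuple) and expr[0] == "multiply":
        return {"x": expr[1], "y": expr[2]}
    return None


def rewrite_until_fixed_point(expr, rewriter, max_iters=10):
    """Hypothetical driver: re-apply the rewriter at the root until nothing changes."""
    for _ in range(max_iters):
        matchings = match_multiply(expr)
        if matchings is None:
            return expr, True  # pattern no longer matches: done
        new_expr = rewriter(matchings)
        if new_expr == expr:
            return expr, True  # structurally unchanged: fixed point reached
        expr = new_expr
    return expr, False  # gave up: no fixed point within max_iters


def case_1(m):
    return ("multiply", m["x"], 2.0)


def case_2(m):
    return ("add", ("layout_transform", m["x"]), 2.0)


def issue_case(m):
    # The output is again a multiply, so it matches the pattern every time.
    return ("multiply", ("layout_transform", m["x"]), 2.0)


start = ("multiply", "x", 2.0)
for name, rewriter in [("case 1", case_1), ("case 2", case_2), ("issue case", issue_case)]:
    _, converged = rewrite_until_fixed_point(start, rewriter)
    print(name, "converged" if converged else "did not converge")
```

A hand-written PyExprMutator typically visits each original call once and does not re-visit its own output, which would explain why the same rewrite logic terminates there.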