[Unity] relax.PyExprMutator loses PrimFunc info

I wrote a relax.PyExprMutator-style mutator and tested it on a relax.IRModule, only to find that the original PrimFunc info is lost. There is also a logic error in the resulting relax main function: the add PrimFunc definition is gone, and the call site refers to a bare add instead of cls.add.

The complete code is:

    import tvm
    from tvm import relax

    device = tvm.cpu(0)
    target = tvm.target.Target.from_device(device)

    @relax.expr_functor.mutator
    class Foo(relax.PyExprMutator):
        def __init__(self, mod, target: tvm.target.Target) -> None:
            super().__init__()
            self.mod_ = mod
            self.target_ = target
            
        def transform(self):
            for global_var, func in self.mod_.functions.items():
                if not isinstance(func, relax.Function):
                    continue
                # avoid already fused primitive functions
                if func.attrs is not None and "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                    continue
                updated_func = self.visit_expr(func)
                updated_func = relax.analysis.remove_all_unused(updated_func)
                self.builder_.update_func(global_var, updated_func)

            return self.builder_.finalize()
        
    from tvm.script import ir as I, relax as R, tir as T 
    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def add(x: T.Buffer((T.int64(2), T.int64(3)), "float16"), x_1: T.Buffer((T.int64(2), T.int64(3)), "float16"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float16")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(x[v_ax0, v_ax1], x_1[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + x_1[v_ax0, v_ax1]

        @T.prim_func(private=True)
        def multiply(x: T.Buffer((T.int64(2), T.int64(3)), "float16"), x_1: T.Buffer((T.int64(2), T.int64(3)), "float16"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float16")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_multiply"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(x[v_ax0, v_ax1], x_1[v_ax0, v_ax1])
                    T.writes(T_multiply[v_ax0, v_ax1])
                    T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * x_1[v_ax0, v_ax1]

        @R.function
        def main(x: R.Tensor((2, 3), dtype="float16"), x_1: R.Tensor((2, 3), dtype="float16")) -> R.Tensor((2, 3), dtype="float16"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.add, (x, x_1), out_sinfo=R.Tensor((2, 3), dtype="float16"))
                lv1: R.Tensor((2, 3), dtype="float16") = R.multiply(x, x_1)
                gv: R.Tensor((2, 3), dtype="float16") = lv
                R.output(gv)
            return gv


    mod = Module 
    mod = Foo(mod, target).transform()
    print(mod)

After running the test code, I got the following relax IRModule. The add reference in it is strange, as there is no definition of add anywhere:

    # from tvm.script import ir as I
    # from tvm.script import relax as R
    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float16"), x_1: R.Tensor((2, 3), dtype="float16")) -> R.Tensor((2, 3), dtype="float16"):
            with R.dataflow():
                lv = R.call_tir(add, (x, x_1), out_sinfo=R.Tensor((2, 3), dtype="float16"))
                gv: R.Tensor((2, 3), dtype="float16") = lv
                R.output(gv)
            return gv

You need to pass mod to the PyExprMutator base class to initialize self.builder_ properly, something like super().__init__(mod).
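For reference, a minimal sketch of the fixed constructor (the rest of Foo stays unchanged). With the module passed through, the builder starts from the existing functions, so finalize() returns the add and multiply PrimFuncs along with the rewritten main:

    @relax.expr_functor.mutator
    class Foo(relax.PyExprMutator):
        def __init__(self, mod, target: tvm.target.Target) -> None:
            # Passing mod seeds self.builder_ with the input IRModule,
            # so the existing PrimFuncs are carried into the output.
            super().__init__(mod)
            self.mod_ = mod
            self.target_ = target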

Thanks very much, it works!