I wrote a relax.PyExprMutator-based mutator and tested it on a relax.IRModule, only to find that the original PrimFunc info is lost in the result, and that the rewritten main function is broken: the add PrimFunc is gone from the module even though the call_tir in main should still point to cls.add.
The complete code is:
import tvm
from tvm import relax

device = tvm.cpu(0)
target = tvm.target.Target.from_device(device)

@relax.expr_functor.mutator
class Foo(relax.PyExprMutator):
    def __init__(self, mod, target: tvm.target.Target) -> None:
        super().__init__()
        self.mod_ = mod
        self.target_ = target

    def transform(self):
        for global_var, func in self.mod_.functions.items():
            # only rewrite Relax functions; PrimFuncs are left untouched
            if not isinstance(func, relax.Function):
                continue
            # avoid already fused primitive functions
            if func.attrs is not None and "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            updated_func = self.visit_expr(func)
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)
        return self.builder_.finalize()
from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(x: T.Buffer((T.int64(2), T.int64(3)), "float16"), x_1: T.Buffer((T.int64(2), T.int64(3)), "float16"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1], x_1[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + x_1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def multiply(x: T.Buffer((T.int64(2), T.int64(3)), "float16"), x_1: T.Buffer((T.int64(2), T.int64(3)), "float16"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1], x_1[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * x_1[v_ax0, v_ax1]

    @R.function
    def main(x: R.Tensor((2, 3), dtype="float16"), x_1: R.Tensor((2, 3), dtype="float16")) -> R.Tensor((2, 3), dtype="float16"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.add, (x, x_1), out_sinfo=R.Tensor((2, 3), dtype="float16"))
            lv1: R.Tensor((2, 3), dtype="float16") = R.multiply(x, x_1)
            gv: R.Tensor((2, 3), dtype="float16") = lv
            R.output(gv)
        return gv
mod = Module
mod = Foo(mod, target).transform()
print(mod)
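
After transform(), the printed module only contains the rewritten main; the add and multiply PrimFuncs are gone, so the cls.add reference in the remaining call_tir is dangling.

My current suspicion: as far as I can tell from expr_functor.py, PyExprMutator.__init__ takes an optional IRModule and uses it to construct self.builder_, so when I call super().__init__() with no arguments, the internal BlockBuilder starts from an empty module and finalize() can only return the functions I explicitly registered through update_func. A minimal sketch of the constructor change I have in mind, assuming that seeding the builder with the input module is the intended usage:

@relax.expr_functor.mutator
class Foo(relax.PyExprMutator):
    def __init__(self, mod, target: tvm.target.Target) -> None:
        # Seed the internal BlockBuilder with the full module, so that
        # finalize() keeps the untouched PrimFuncs (add, multiply) and
        # the cls.add GlobalVar in main still resolves.
        super().__init__(mod)
        self.mod_ = mod
        self.target_ = target

Is passing the module to super().__init__ the right way to preserve the rest of the IRModule, or is there a recommended pattern for this kind of mutator?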