How to fuse the ops in concat([op0, op1])?

My Relax IR is as follows:

I need to fuse the ops in the red box as shown. I then wrote a pattern and applied the `relax.transform.FuseOpsByPattern` pass, but the resulting module is not well-formed.

from tvm import relax
from tvm.script.parser import relax as R
from tvm.relax.dpl import is_op, wildcard, is_tuple

def _cocnat_pattern():
    """Return a FuseOpsByPattern entry matching ``concat((abs(_), abs(_)))``.

    The pattern matches two ``relax.abs`` calls whose results are packed into
    a tuple that is the input of ``relax.concat``.

    Returns:
        A 3-tuple ``(pattern_name, root_pattern, annotations)`` used as one
        entry of the pattern table passed to
        ``relax.transform.FuseOpsByPattern``.
    """
    dq0 = is_op("relax.abs")(wildcard())
    dq1 = is_op("relax.abs")(wildcard())
    # relax.concat takes its inputs as a single tuple argument, so match that
    # tuple explicitly and feed the tuple pattern (not a raw Python list) to
    # the concat call pattern.
    inp = is_tuple([dq0, dq1])
    out = is_op("relax.concat")(inp)
    annotations = {"root": out}

    return ("My_fused.concat_abs", out, annotations)

def test():
    """Check that fusing ``concat((abs(x), abs(y)))`` keeps the module well-formed.

    Builds ``main(x, y) = relu(concat((abs(x), abs(y))))`` over two ``(100,)``
    int8 inputs and asserts the module is well-formed. It then applies
    ``FuseOpsByPattern`` with the pattern from ``_cocnat_pattern`` and asserts
    the fused module is still well-formed.
    """
    builder = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((100,), "int8"))
    y = relax.Var("y", R.Tensor((100,), "int8"))
    with builder.function("main", [x, y]):
        with builder.dataflow():
            # Bind each intermediate op explicitly as its own dataflow binding.
            abs_x = builder.emit(relax.op.abs(x))
            abs_y = builder.emit(relax.op.abs(y))
            joined = builder.emit(relax.op.concat([abs_x, abs_y]))
            result = builder.emit_output(relax.op.nn.relu(joined))
        builder.emit_func_output(result)

    before = builder.get()
    assert relax.analysis.well_formed(before)
    fuse_pass = relax.transform.FuseOpsByPattern((_cocnat_pattern(),))
    after = fuse_pass(before)
    assert relax.analysis.well_formed(after)

if __name__ == "__main__":
    # Run the reproduction when this file is executed directly as a script.
    test()

I've found a way to resolve this problem: [Relax] Fix issue in fuse concat ops by pattern by cccxinli · Pull Request #18163 · apache/tvm · GitHub

2 Likes