[SOLVED] Beginner help with FuseOpsByPattern

I can’t get FuseOpsByPattern to work with this simple example. I expect my pattern to match once, with a separate function being created for it as well as a call to that function in main. What am I missing?

from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm import relax

from tvm.relax.dpl import is_op, wildcard
from tvm.relax.transform import FuseOpsByPattern


# Minimal Relax module used to try out FuseOpsByPattern: matmul -> relu -> matmul
# on three (128, 128) float32 tensors.
@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 128), "float32"),
        B: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"num_input": 1})
        # NOTE(review): these bindings sit outside any `with R.dataflow():` block;
        # in this form the pass output shown in this thread leaves them unfused.
        lv0 = R.matmul(x, A)
        Y = R.nn.relu(lv0)
        Z = R.matmul(Y, B)
        return Z

def mlp_pattern():
    """Build the pattern ``relax.nn.relu(relax.matmul(lhs, rhs))``.

    Both matmul operands are wildcards, so no constraint is placed on them.

    Returns:
        The pattern rooted at the ``relax.nn.relu`` call.
    """
    lhs = wildcard()
    rhs = wildcard()

    matmul = is_op("relax.matmul")(lhs, rhs)
    return is_op("relax.nn.relu")(matmul)

# A single (name, pattern) pair: the matmul -> relu pattern registered as "mlp_block".
fusion_patterns = [("mlp_block", mlp_pattern())]

# Construct the pass first, then apply it to the module as a separate step.
fuse_pass = FuseOpsByPattern(
    patterns=fusion_patterns,
    annotate_codegen=False,
    bind_constants=False,
)
mod = fuse_pass(Model)

# Print the transformed module as TVMScript.
mod.show()

# from tvm.script import ir as I
# from tvm.script import relax as R
#
# @I.ir_module
# class Module:
#     @R.function
#     def main(x: R.Tensor((128, 128), dtype="float32"), A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
#         R.func_attr({"num_input": 1})
#         lv0: R.Tensor((128, 128), dtype="float32") = R.matmul(x, A, out_dtype="void")
#         Y: R.Tensor((128, 128), dtype="float32") = R.nn.relu(lv0)
#         Z: R.Tensor((128, 128), dtype="float32") = R.matmul(Y, B, out_dtype="void")
#         return Z

I tried your code with the following modification and got what you probably want:

# Replacement body for Model.main: the same three bindings, now inside a dataflow block.
with R.dataflow():
    lv0 = R.matmul(x, A)
    Y = R.nn.relu(lv0)
    Z = R.matmul(Y, B)
    # Z is declared as the block's output so it can be returned below.
    R.output(Z)
return Z

Just wrap the body of main in a `with R.dataflow():` block (ending with `R.output(Z)`); the result is as follows:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    # Function created by the pass for the matched matmul -> relu bindings.
    @R.function(private=True)
    def fused_relax_matmul_relax_nn_relu(x: R.Tensor((128, 128), dtype="float32"), A: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        # "Composite" holds the pattern name that was passed to FuseOpsByPattern.
        R.func_attr({"Composite": "mlp_block", "Primitive": 1})
        with R.dataflow():
            lv0: R.Tensor((128, 128), dtype="float32") = R.matmul(x, A, out_dtype="void")
            gv: R.Tensor((128, 128), dtype="float32") = R.nn.relu(lv0)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((128, 128), dtype="float32"), A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        with R.dataflow():
            # The first matmul and the relu are now a single call to the fused function.
            lv: R.Tensor((128, 128), dtype="float32") = cls.fused_relax_matmul_relax_nn_relu(x, A)
            # The second matmul is not covered by the pattern and stays in main.
            Z: R.Tensor((128, 128), dtype="float32") = R.matmul(lv, B, out_dtype="void")
            R.output(Z)
        return Z
1 Like