I can’t get FuseOpsByPattern
to work with this simple example. I am expecting my pattern to match once an a separate function being created for it, as well as a call to it in main
. What am I missing?
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm import relax
from tvm.relax.dpl import is_op, wildcard
from tvm.relax.transform import FuseOpsByPattern
@I.ir_module
class Model:
@R.function
def main(
x: R.Tensor((128, 128), "float32"),
A: R.Tensor((128, 128), "float32"),
B: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"num_input": 1})
lv0 = R.matmul(x, A)
Y = R.nn.relu(lv0)
Z = R.matmul(Y, B)
return Z
def mlp_pattern():
x = wildcard()
w1 = wildcard()
y = is_op("relax.matmul")(x, w1)
y = is_op("relax.nn.relu")(y)
return y
mod = FuseOpsByPattern(
patterns=[("mlp_block", mlp_pattern())],
annotate_codegen=False,
bind_constants=False,
)(Model)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
#
# @I.ir_module
# class Module:
# @R.function
# def main(x: R.Tensor((128, 128), dtype="float32"), A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
# R.func_attr({"num_input": 1})
# lv0: R.Tensor((128, 128), dtype="float32") = R.matmul(x, A, out_dtype="void")
# Y: R.Tensor((128, 128), dtype="float32") = R.nn.relu(lv0)
# Z: R.Tensor((128, 128), dtype="float32") = R.matmul(Y, B, out_dtype="void")
# return Z