BlockBuilder: emitting TIR straight away [SOLVED]

How do I add TIR to a BlockBuilder?

What I am trying to achieve is a simple mapping of Relax operators to handcrafted tir.PrimFuncs during LegalizeOps, but I haven’t managed to get it to work yet.

Here is my attempt:

import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

# Handcrafted TIR kernel that forwards a float32 matmul to the external
# symbol "matmul_f32". It passes the output pointer first, then the A and B
# pointers, then the M, N, K extents.
# Buffer shapes: A is [M, K], B is [K, N], out is [M, N].
@T.prim_func
def matmul_fp32(
    A_handle: T.handle,
    B_handle: T.handle,
    out_handle: T.handle,
    M: T.int64,
    N: T.int64,
    K: T.int64,
) -> None:
    # function attr dict
    T.func_attr({"tir.noalias": True})
    A = T.match_buffer(A_handle, [M, K], dtype="float32")
    B = T.match_buffer(B_handle, [K, N], dtype="float32")
    out = T.match_buffer(out_handle, [M, N], dtype="float32")

    with T.block("root"):
        T.reads(A[0:M, 0:K], B[0:K, 0:N])
        # The extern call writes the whole `out` buffer. Write regions must
        # name a buffer declared in this PrimFunc.
        T.writes(out[0:M, 0:N])
        T.evaluate(
            T.call_extern(
                "matmul_f32",
                out.access_ptr("w"),
                A.access_ptr("r"),
                B.access_ptr("r"),
                M,
                N,
                K,
                dtype="float32",
            )
        )

# The question's original (failing) attempt, kept as-is so it matches the
# traceback below.
A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32"))
B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))
# NOTE(review): this Var is bound to C but given the name hint "B", likely a typo.
C = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))

bb = relax.BlockBuilder()

with bb.function("main"):
    # R.dataflow() is a TVMScript IRBuilder frame, not a BlockBuilder scope.
    # Entering it here raises "No builder in current scope" (see the traceback).
    # The BlockBuilder equivalent is bb.dataflow().
    with R.dataflow():
        # NOTE(review): the third argument is the relax Var C rather than an
        # output struct info. Also, matmul_fp32 is not registered via
        # bb.add_func first. The working version below does both.
        C = bb.emit(R.call_tir(matmul_fp32, [A, B], C, dtype="float32"))
        # Inside a dataflow block, outputs are emitted with bb.emit_output().
        # emit_func_output closes the enclosing function instead.
        out = bb.emit_func_output(C)
    bb.emit_func_output(out, params=[A, B])

The error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[41], line 3
      1 with bb.function("main"):
----> 3    with R.dataflow():
      5        C = bb.emit(R.call_tir(matmul_rrr_fp32, [A, B], C, dtype="float32"))
      6        out = bb.emit_func_output(C)

File ~/tvm/python/tvm/script/ir_builder/base.py:64, in IRBuilderFrame.__enter__(self)
     63 def __enter__(self) -> "IRBuilderFrame":
---> 64     _ffi_api.IRBuilderFrameEnter(self)  # type: ignore[attr-defined] # pylint: disable=no-member
     65     return self

File ~/tvm/python/tvm/_ffi/_ctypes/packed_func.py:237, in PackedFuncBase.__call__(self, *args)
    225 ret_tcode = ctypes.c_int()
    226 if (
    227     _LIB.TVMFuncCall(
    228         self.handle,
   (...)
    235     != 0
    236 ):
--> 237     raise get_last_ffi_error()
    238 _ = temp_args
    239 _ = args

ValueError: Traceback (most recent call last):
  File "/Users/user/tvm/src/script/ir_builder/base.cc", line 76
ValueError: Check failed: (!stack->empty()) is false: No builder in current scope

Please try:

with bb.function("main"):
    with bb.dataflow():
        # Register the handcrafted PrimFunc in the module being built. The
        # returned GlobalVar is what call_tir refers to.
        matmul_gv = bb.add_func(matmul_fp32, "matmul_fp32")
        out_sinfo = relax.TensorStructInfo((128, 128), "float32")
        # emit_output makes C usable after the dataflow block closes.
        C = bb.emit_output(relax.call_tir(matmul_gv, [A, B], out_sinfo))
    # Finish "main", returning C and declaring A and B as its parameters.
    bb.emit_func_output(C, params=[A, B])

Writing the whole IRModule in TVMScript also helps :slight_smile:

1 Like

Thank you, that’s working perfectly. The error also makes more sense to me now.

Would writing the whole IRModule in TVMScript also help when writing handcrafted kernels for legalisation?

I am sharing the solution to my initial problem in case it helps other newbies. To recap, my first goal was to map high-level ops straight to handcrafted TIR; this can be done with LegalizeOps like this:

@tvm.script.ir_module
class Module:
    # Input module: two chained R.matmul calls on 128x128 float32 tensors.
    # The LegalizeOps call below rewrites each R.matmul into an R.call_tir.
    @R.function
    def main(A: R.Tensor((128, 128), "float32"), B: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
        y = R.matmul(A, B)
        # The second matmul reuses A as its left operand.
        z = R.matmul(A, y)
        return z

def customize_legalize_matmul(bb, call):
    """Legalize a ``relax.matmul`` call into a ``call_tir`` of ``matmul_fp32``.

    Intended as a custom entry in the map passed to
    ``relax.transform.LegalizeOps``.

    Parameters
    ----------
    bb : relax.BlockBuilder
        Builder used to register the PrimFunc and emit the new binding.
    call : relax.Call
        The ``relax.matmul`` call being legalized. ``call.args[0]`` and
        ``call.args[1]`` are its two operands.

    Returns
    -------
    relax.Var
        The variable bound to the emitted ``call_tir``.
    """
    # Called once per matmul. The legalized module printed below contains a
    # single matmul_fp32, so repeated add_func calls do not duplicate it.
    gv = bb.add_func(matmul_fp32, "matmul_fp32")
    A, B = call.args[0], call.args[1]
    # Reuse the output struct info already inferred for the matmul instead of
    # hard-coding (128, 128) float32, so other 2-D shapes are annotated
    # correctly.
    # NOTE(review): matmul_fp32 also declares scalar params M, N, K, which this
    # call_tir does not pass (no tir_vars). Confirm this before building or
    # running the module. The extern kernel is also float32-only, whatever the
    # dtype of the call.
    return bb.emit(relax.call_tir(gv, [A, B], call.struct_info))

# Route "relax.matmul" through the handcrafted legalization above.
legalize_map = {"relax.matmul": customize_legalize_matmul}
mod = relax.transform.LegalizeOps(legalize_map)(Module)
mod.show()

Output:

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R


@I.ir_module
class Module:
    # Legalized module printed by mod.show(). The handcrafted kernel appears
    # once, and both former R.matmul calls invoke it through R.call_tir.
    # NOTE(review): parameter/buffer names (var_a, C, ...) differ from the
    # matmul_fp32 definition at the top of the thread (A_handle, out, ...).
    # This output was presumably produced from a slightly different version
    # of that kernel.
    @T.prim_func
    def matmul_fp32(
        var_a: T.handle,
        var_b: T.handle,
        var_c: T.handle,
        M: T.int64,
        N: T.int64,
        K: T.int64,
    ):
        T.func_attr({"tir.noalias": True})
        A = T.match_buffer(var_a, (M, K))
        B = T.match_buffer(var_b, (K, N))
        C = T.match_buffer(var_c, (M, N))
        with T.block("root"):
            T.reads(A[0:M, 0:K], B[0:K, 0:N])
            T.writes(C[0:M, 0:N])
            # Each buffer.access_ptr(...) is printed as T.tvm_access_ptr.
            # The last argument is the access mask: 2 for the output taken
            # with "w", 1 for the inputs taken with "r".
            T.call_extern(
                "float32",
                "matmul_f32",
                T.tvm_access_ptr(
                    T.type_annotation("float32"), C.data, T.int64(0), M * N, 2
                ),
                T.tvm_access_ptr(
                    T.type_annotation("float32"), A.data, T.int64(0), M * K, 1
                ),
                T.tvm_access_ptr(
                    T.type_annotation("float32"), B.data, T.int64(0), K * N, 1
                ),
                M,
                N,
                K,
            )

    @R.function
    def main(
        A: R.Tensor((128, 128), dtype="float32"),
        B: R.Tensor((128, 128), dtype="float32"),
    ) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        # NOTE(review): these call_tir calls pass only tensors. matmul_fp32's
        # M, N, K scalar parameters are not supplied via tir_vars, so confirm
        # this before building or running the module.
        gv = R.call_tir(
            cls.matmul_fp32, (A, B), out_sinfo=R.Tensor((128, 128), dtype="float32")
        )
        y: R.Tensor((128, 128), dtype="float32") = gv
        gv1 = R.call_tir(
            cls.matmul_fp32, (A, y), out_sinfo=R.Tensor((128, 128), dtype="float32")
        )
        z: R.Tensor((128, 128), dtype="float32") = gv1
        return z
1 Like