How do I add TIR to a BlockBuilder
?
What I am trying to achieve is a simple mapping of relax operators to handcrafted tir.PrimFunc
during LegalizeOps
, I haven’t managed to get it to work yet.
Here is my attempt:
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T
@T.prim_func
def matmul_fp32(
A_handle: T.handle,
B_handle: T.handle,
out_handle: T.handle,
M: T.int64,
N: T.int64,
K: T.int64,
) -> None:
# function attr dict
T.func_attr({"tir.noalias": True})
A = T.match_buffer(A_handle, [M, K], dtype="float32")
B = T.match_buffer(B_handle, [K, N], dtype="float32")
out = T.match_buffer(out_handle, [M, N], dtype="float32")
with T.block("root"):
T.reads(A[0:M, 0:K], B[0:K, 0:N])
T.writes(C[0:M, 0:N])
T.evaluate(
T.call_extern(
"matmul_f32",
out.access_ptr("w"),
A.access_ptr("r"),
B.access_ptr("r"),
M,
N,
K,
dtype="float32",
)
)
A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32"))
B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))
C = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))
bb = relax.BlockBuilder()
with bb.function("main"):
with R.dataflow():
C = bb.emit(R.call_tir(matmul_fp32, [A, B], C, dtype="float32"))
out = bb.emit_func_output(C)
bb.emit_func_output(out, params=[A, B])
The error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[41], line 3
1 with bb.function("main"):
----> 3 with R.dataflow():
5 C = bb.emit(R.call_tir(matmul_rrr_fp32, [A, B], C, dtype="float32"))
6 out = bb.emit_func_output(C)
File ~/tvm/python/tvm/script/ir_builder/base.py:64, in IRBuilderFrame.__enter__(self)
63 def __enter__(self) -> "IRBuilderFrame":
---> 64 _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
65 return self
File ~/tvm/python/tvm/_ffi/_ctypes/packed_func.py:237, in PackedFuncBase.__call__(self, *args)
225 ret_tcode = ctypes.c_int()
226 if (
227 _LIB.TVMFuncCall(
228 self.handle,
(...)
235 != 0
236 ):
--> 237 raise get_last_ffi_error()
238 _ = temp_args
239 _ = args
ValueError: Traceback (most recent call last):
File "/Users/user/tvm/src/script/ir_builder/base.cc", line 76
ValueError: Check failed: (!stack->empty()) is false: No builder in current scope