I want to integrate a TIR schedule into an IRModule so that I can later use `relay.build`. I found this thread, which introduced me to `call_lowered`, so I wanted to try it. This is the code I have:
# Assemble an IRModule whose Relay entry point "main" invokes the scheduled
# TIR PrimFunc (from `sch`) through the `call_lowered` op.
# Relies on `inp_shape`, `wght_shape`, `out_shape` and the schedule `sch`,
# which are defined earlier in the script (not shown here).
mod = tvm.ir.IRModule()
mainFnVar = relay.GlobalVar("main")
call_lowered = tvm.relay.op.get("call_lowered")

# The PrimFunc's buffers (a_in, b_in, res) are all int8, so the Relay inputs
# must be int8 as well; relay.var defaults to float32 when no dtype is given,
# which produced `Tensor[(128, 1024), float32]` in the printed module.
ifm = relay.var("ifm", shape=inp_shape, dtype="int8")
wgt = relay.var("wgt", shape=wght_shape, dtype="int8")
args = relay.Tuple([ifm, wgt])

sb = relay.ScopeBuilder()
vortexFnVar = relay.GlobalVar("vortexMatmulTir")
attrs = tvm.ir.make_node(
    "relay.attrs.CallLoweredAttrs", **{"metadata": {"relay_attrs": ""}}
)
tmp = sb.let("result", relay.Call(call_lowered, [vortexFnVar, args], attrs=attrs))
sb.ret(tmp)
mainFn = relay.Function(
    [ifm, wgt], sb.get(), tvm.ir.TensorType(out_shape, dtype="int8")
)
mod.update_func(mainFnVar, mainFn)

# sch.mod["main"] carries the attribute global_symbol="main", which collides
# with the Relay entry point of the same name in this module. Give the
# PrimFunc a symbol that matches the GlobalVar it is registered under.
vortexFn = sch.mod["main"].with_attr("global_symbol", vortexFnVar.name_hint)
mod.update_func(vortexFnVar, vortexFn)

# NOTE(review): the reported "NodeFunctor calls un-registered function on type
# tir.PrimFunc" suggests some Relay pass run by relay.build visits every
# function in the module and has no case for tir.PrimFunc. Whether the TVM
# version in use accepts pre-lowered PrimFuncs in relay.build's input module
# is not established here -- TODO confirm which pass raises the error.
The IRModule also looks as I would expect:
@vortexMatmulTir = primfn(var_a_in: handle, var_b_in: handle, var_res: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {a_in: Buffer(a_in_1: Pointer(global int8), int8, [128, 1024], []),
b_in: Buffer(b_in_1: Pointer(global int8), int8, [1024, 512], []),
res: Buffer(res_1: Pointer(global int8), int8, [128, 512], [])}
buffer_map = {var_a_in: a_in, var_b_in: b_in, var_res: res} {
block([], "root") {
tir.reads([])
tir.writes([])
{
block([1, 1], "res_init_o") as [v_i_o, v_j_o] {
bind(v_i_o, 0)
bind(v_j_o, 0)
tir.reads([])
tir.writes([res[0:128, 0:512]])
C = match_buffer(res[0:128, 0:512])
0
block([1, 1, tir.reduce_axis(0, 1)], "res_update_o") as [v_i_o_1, v_j_o_1, v_k_o] {
bind(v_i_o_1, 0)
bind(v_j_o_1, 0)
bind(v_k_o, 0)
tir.reads([res[0:128, 0:512], a_in[0:128, 0:1024], b_in[0:1024, 0:512]])
tir.writes([res[0:128, 0:512]])
A = match_buffer(a_in[0:128, 0:1024])
B = match_buffer(b_in[0:1024, 0:512])
C_1 = match_buffer(res[0:128, 0:512])
@tir.call_extern("tiled_matmul", dtype=)
}
}
def @main(%ifm: Tensor[(128, 1024), float32], %wgt: Tensor[(1024, 512), float32]) -> Tensor[(128, 512), int8] {
%0 = (%ifm, %wgt);
let %result = call_lowered(@vortexMatmulTir, %0, metadata={"relay_attrs"=""});
%result
}
However, when I call `relay.build`, I still get the same error I ran into when I tried to build `sch.mod` directly, without using `call_lowered`:
InternalError: Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type tir.PrimFunc
What am I doing wrong here?