Calling a PrimFunc from Relay and building with relay.build

Hi everyone,

I am trying to schedule the operations of a Relay graph with TIR and the MetaSchedule (MS) backend. The computation is originally defined in TE and then converted into a tir.Schedule. The code looks like this:

import tvm
from tvm import te, relay
from tvm import meta_schedule as ms

# Placeholders and the reduction axis (shapes and dtypes here are
# illustrative; the real values come from my model)
inp = te.placeholder(inp_shape, dtype="int8", name="inp")
wght = te.placeholder(wght_shape, dtype="int8", name="wght")
rk = te.reduce_axis((0, inp_shape[1]), name="rk")

res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(out_dtype) * wght[rk, j].astype(out_dtype),
        axis=[rk],
    ),
    name="res",
    tag="dense",
)

# Turn the TE computation into a PrimFunc and wrap it in a TIR schedule
func = te.create_prim_func([inp, wght, res])
sch = tvm.tir.Schedule(func)

# My custom scheduling function for Gemmini
schedule_matmul_gemmini(sch, do_tune=False)
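
As a sanity check, I print the scheduled module at this point to confirm that the PrimFunc looks as expected:

# Inspect the scheduled TIR module (sch.mod holds the PrimFunc under "main")
print(sch.mod.script())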

I found this post that hinted at call_lowered, so I added the required parts to my flow. This is the code that builds the module:

# Put TIR schedule into a Relay function
mod = tvm.ir.IRModule()
mainFnVar = relay.GlobalVar("main")

call_lowered = tvm.relay.op.get("call_lowered")
ifm = relay.var("ifm", shape=inp_shape)
wgt = relay.var("wgt", shape=wght_shape)
args = relay.Tuple([ifm, wgt])

sb = relay.ScopeBuilder()
gemminiFnVar = relay.GlobalVar("gemminiMatmulTir")
# Construct the CallLoweredAttrs; this is the part I am least sure about
attrs = tvm.ir.make_node(
    "relay.attrs.CallLoweredAttrs", **{"metadata": {"relay_attrs": ""}}
)
tmp = sb.let("result", relay.Call(call_lowered, [gemminiFnVar, args], attrs=attrs))
sb.ret(tmp)
mainFn = relay.Function([ifm, wgt], sb.get(), tvm.ir.TensorType(out_shape, dtype="int8"))

mod.update_func(mainFnVar, mainFn)
mod.update_func(gemminiFnVar, sch.mod["main"])

with tvm.transform.PassContext(config=config, opt_level=3), ms.database.ScheduleFnDatabase(schedule_fn):
    #module = tvm.build(sch.mod, [inp, wght, res], name="matmul_test", binds={inp: input_data, wght: weight_data, res: res_buf}, target=TARGET,)
    module = tvm.relay.build(mod, executor=EXECUTOR, runtime=RUNTIME, target=TARGET)
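
As a side note, to narrow things down I also inspect the module before the build; this is only a sanity check and not part of the actual flow:

# Run type inference on its own to see whether the module is well-formed
# before invoking the full relay.build pipeline
mod = relay.transform.InferType()(mod)
print(mod)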

However, the relay.build call throws the following error trace:

tvm.error.InternalError: Traceback (most recent call last):
  13: tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  12: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  11: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::ToBasicBlockNormalForm()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::ToBasicBlockNormalForm()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::relay::ToBasicBlockNormalForm(tvm::IRModule const&)
  3: tvm::relay::FreeVars(tvm::RelayExpr const&)
  2: tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)
  1: tvm::relay::MixedModeVisitor::VisitLeaf(tvm::RelayExpr const&)
  0: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  File "/home/rh8588/Dokumente/git/tvm/include/tvm/node/functor.h", line 95
InternalError: Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type tir.PrimFunc

I think this is related to tvm.ir.make_node: the forum post mentioned that call_lowered needs some additional pieces, like the attributes, and I most likely did not set them up correctly. What can I do from here?
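
One guess on my side: maybe the PrimFunc needs additional function attributes (for example a global_symbol) before it is registered in the IRModule, so that the Relay passes can recognize it. Something along these lines, although I have not verified that this is the right set of attributes:

# Untested idea: attach a global_symbol to the PrimFunc before inserting it,
# so that downstream passes can identify it
prim_func = sch.mod["main"].with_attr("global_symbol", "gemminiMatmulTir")
mod.update_func(gemminiFnVar, prim_func)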

@junrushao Since you were involved in the other post I mentioned, do you maybe have an idea of what I need to include here?