Hi everyone,
I am trying to build scheduled operations in a Relay graph using TIR and the MetaSchedule (MS) backend. The computation is originally defined in TE and then converted to a `tir.Schedule`. The code looks like this:
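```python
import tvm
from tvm import te

# inp and wght are te.placeholder tensors, rk is a te.reduce_axis;
# out_shape and out_dtype are defined earlier
res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(out_dtype) * wght[rk, j].astype(out_dtype),
        axis=[rk],
    ),
    name="res",
    tag="dense",
)
func = te.create_prim_func([inp, wght, res])
sch = tvm.tir.Schedule(func)
schedule_matmul_gemmini(sch, do_tune=False)  # custom Gemmini schedule, defined elsewhere
```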
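For validating what the Gemmini schedule produces, a plain reference for what the `te.compute` above describes can be handy: `res[i, j]` is the sum over `rk` of `inp[i, rk] * wght[rk, j]`, i.e. an (M, K) by (K, N) matrix multiply. The sketch below is pure Python, not TVM code; `reference_matmul` and `wrap_signed` are hypothetical helper names. It assumes integer inputs and, when `wrap_bits` is given, two's-complement wraparound for the `astype(out_dtype)` accumulation, which is an assumption about the generated code rather than something TVM guarantees.

```python
from typing import List, Optional

Matrix = List[List[int]]


def wrap_signed(value: int, bits: int) -> int:
    """Wrap an integer into the signed two's-complement range of `bits` width."""
    half = 1 << (bits - 1)
    return (value + half) % (1 << bits) - half


def reference_matmul(inp: Matrix, wght: Matrix, wrap_bits: Optional[int] = None) -> Matrix:
    """Hypothetical reference: res[i][j] = sum over rk of inp[i][rk] * wght[rk][j]."""
    m, k = len(inp), len(inp[0])
    if len(wght) != k:
        raise ValueError(f"reduction dims differ: inp has K={k}, wght has K={len(wght)}")
    n = len(wght[0])
    res = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0
            for rk in range(k):  # rk plays the role of the te.reduce_axis
                acc += inp[i][rk] * wght[rk][j]
            # Wrapping once at the end equals wrapping after every multiply/add,
            # because reduction modulo 2**bits commutes with + and *.
            res[i][j] = wrap_signed(acc, wrap_bits) if wrap_bits is not None else acc
    return res


if __name__ == "__main__":
    print(reference_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
    print(reference_matmul([[100]], [[2]], wrap_bits=8))
```

Once the build succeeds, comparing the module's output element-wise against this reference is one simple way to check the schedule.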
I found this post that hinted at `call_lowered`, so I added the required parts to my flow. This is the code to build the module:
```python
import tvm
from tvm import relay
from tvm import meta_schedule as ms

# Put TIR schedule into a Relay function
mod = tvm.ir.IRModule()
mainFnVar = relay.GlobalVar("main")
call_lowered = tvm.relay.op.get("call_lowered")
ifm = relay.var("ifm", shape=inp_shape)
wgt = relay.var("wgt", shape=wght_shape)
args = relay.Tuple([ifm, wgt])
sb = relay.ScopeBuilder()
gemminiFnVar = relay.GlobalVar("gemminiMatmulTir")
attrs = tvm.ir.make_node(
    "relay.attrs.CallLoweredAttrs", **{"metadata": {"relay_attrs": ""}}
)
tmp = sb.let("result", relay.Call(call_lowered, [gemminiFnVar, args], attrs=attrs))
sb.ret(tmp)
mainFn = relay.Function([ifm, wgt], sb.get(), tvm.ir.TensorType(out_shape, dtype="int8"))
mod.update_func(mainFnVar, mainFn)
mod.update_func(gemminiFnVar, sch.mod["main"])
with tvm.transform.PassContext(config=config, opt_level=3), ms.database.ScheduleFnDatabase(schedule_fn):
    # module = tvm.build(sch.mod, [inp, wght, res], name="matmul_test", binds={inp: input_data, wght: weight_data, res: res_buf}, target=TARGET,)
    # build the combined module assembled above (sch.mod holds only the PrimFunc)
    module = tvm.relay.build(mod, executor=EXECUTOR, runtime=RUNTIME, target=TARGET)
```
However, this throws the following error trace:
```
tvm.error.InternalError: Traceback (most recent call last):
  13: tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  12: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  11: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::ToBasicBlockNormalForm()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::ToBasicBlockNormalForm()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::relay::ToBasicBlockNormalForm(tvm::IRModule const&)
  3: tvm::relay::FreeVars(tvm::RelayExpr const&)
  2: tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)
  1: tvm::relay::MixedModeVisitor::VisitLeaf(tvm::RelayExpr const&)
  0: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  File "/home/rh8588/Dokumente/git/tvm/include/tvm/node/functor.h", line 95
InternalError: Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type tir.PrimFunc
```
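Reading the trace bottom-up: `ToBasicBlockNormalForm` (frame 4) calls `FreeVars` (frame 3), a Relay expression visitor, and the functor it dispatches through has no handler registered for `tir.PrimFunc`, so the `can_dispatch` check fails. As a rough mental model only, the toy Python analogue below mimics that dispatch-by-node-type pattern. It is not TVM's implementation; `ToyFunctor`, `RelayFunction` and `PrimFunc` are hypothetical stand-ins, and the error text merely mirrors the message in the trace.

```python
from typing import Callable, Dict, Type


class RelayFunction:  # hypothetical stand-in for relay.Function
    def __init__(self, params):
        self.params = params


class PrimFunc:  # hypothetical stand-in for tir.PrimFunc
    pass


class ToyFunctor:
    """Dispatch table keyed on the node's exact type, in the spirit of NodeFunctor."""

    def __init__(self) -> None:
        self._table: Dict[Type, Callable] = {}

    def set_dispatch(self, node_type: Type, fn: Callable) -> "ToyFunctor":
        self._table[node_type] = fn
        return self

    def can_dispatch(self, node) -> bool:
        return type(node) in self._table

    def __call__(self, node):
        if not self.can_dispatch(node):
            raise RuntimeError(
                "Check failed: (can_dispatch(n)) is false: NodeFunctor calls "
                f"un-registered function on type {type(node).__name__}"
            )
        return self._table[type(node)](node)


# A Relay-only visitor: only the Relay function type gets a handler
# (a closed function has no free variables, hence the empty list).
free_vars = ToyFunctor().set_dispatch(RelayFunction, lambda f: [])

# A module holding both kinds of function, like `mod` after both update_func calls.
module = {"main": RelayFunction(params=["ifm", "wgt"]), "gemminiMatmulTir": PrimFunc()}

if __name__ == "__main__":
    for name, fn in module.items():
        try:
            print(name, "->", free_vars(fn))
        except RuntimeError as err:
            print(name, "->", err)
```

In this toy, a pass that visits every function in the module succeeds on the Relay `main` and fails on the PrimFunc entry, which is the shape of the failure in the trace.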
I think it is related to `tvm.ir.make_node`, since the forum post mentioned that some parts of the `call_lowered` setup were missing, such as the attributes, and I most likely did not set them up correctly. What can I do from here on?