Calling a PrimFunc from TVM IR

Hi @abel-bernabeu!

I’m not 100% sure if this will work, but you can try using the call_lowered Relay op, which I introduced in https://github.com/apache/tvm/pull/9312/.

This is our internal calling convention for lowered functions. In theory, you should be able to pass the GlobalVar that refers to your PrimFunc as the function (first) argument of call_lowered.

You’ll want to do something like this:

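```python
import tvm
from tvm import relay

# Assumed to be defined elsewhere: n, m (static dims), mod (an IRModule that
# already contains the PrimFunc) and opFnVar (the GlobalVar bound to it).
mainFnVar = relay.GlobalVar("main")
x = relay.var('a', shape=(n, m))
row = relay.var('row', shape=[1])
sb = relay.ScopeBuilder()
b = sb.let('var', relay.const([[0,0,0,0,0,0],[0,0,0,0,0,0],[0,0,0,0,0,0],], dtype="int32"))
call_lowered = relay.op.get("call_lowered")
# call_lowered takes the callee first, then all of its arguments packed in one Tuple.
args = relay.Tuple([x, row, b])
# attrs and span are placeholders (not defined here). Pass them by keyword,
# since the positional argument after attrs is type_args.
tmp = sb.let("result", relay.Call(call_lowered, [opFnVar, args], attrs=attrs, span=span))
sb.ret(tmp)
mainFn = relay.Function([x, row], sb.get(), tvm.ir.TensorType([n, m]))
mod.update_func(mainFnVar, mainFn)
```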

call_lowered hasn’t been set up to work from Python (as you can see, I didn’t define attrs or span), so I’d suggest trying it in C++ instead. This example might be useful: https://github.com/apache/tvm/blob/main/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc#L119
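To make the argument layout of the convention concrete, here is a toy, pure-Python model of it. The callee goes in as the first argument and all real arguments are packed into a single tuple as the second. None of these classes are TVM APIs. `GlobalVar`, `Var`, `Tuple`, `Call`, `make_call_lowered`, `describe` and the name `my_prim_func` are hypothetical stand-ins used only to show the shape of the resulting call node. This is a sketch, not how TVM builds it internally.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for Relay IR nodes; these are NOT the tvm.relay classes.
@dataclass(frozen=True)
class GlobalVar:
    name_hint: str

@dataclass(frozen=True)
class Var:
    name_hint: str

@dataclass(frozen=True)
class Tuple:
    fields: tuple

@dataclass(frozen=True)
class Call:
    op: str
    args: tuple
    attrs: Optional[dict] = None

def make_call_lowered(callee: GlobalVar, args: List[object], attrs: Optional[dict] = None) -> Call:
    """Build call_lowered(callee, (arg0, arg1, ...)) in the toy IR."""
    if not isinstance(callee, GlobalVar):
        raise TypeError("call_lowered expects a GlobalVar naming the PrimFunc")
    # The callee is the first argument; every real argument is packed into one Tuple.
    return Call(op="call_lowered", args=(callee, Tuple(tuple(args))), attrs=attrs)

def describe(call: Call) -> str:
    """Render the toy call node as text."""
    callee, packed = call.args
    names = ", ".join(a.name_hint for a in packed.fields)
    return f"{call.op}(@{callee.name_hint}, ({names}))"

# Mirrors the Relay snippet: x is named 'a', plus row and the let-bound 'var'.
op_fn_var = GlobalVar("my_prim_func")  # hypothetical PrimFunc name
call = make_call_lowered(op_fn_var, [Var("a"), Var("row"), Var("var")])
print(describe(call))  # prints "call_lowered(@my_prim_func, (a, row, var))"
```

The point of packing the arguments into one tuple is that call_lowered always has exactly two arguments, however many the PrimFunc takes. The real attrs and span still have to be supplied on the C++ side.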