Does TIR support cross-function calls?

Hi all, I’m a new TVM user and am learning TIR. I want to define my own operator using the TIR interface, which means creating a top-level PrimFunc, initializing an IRModule from it, and building that IRModule.

The problem is that I also want to define multiple PrimFuncs that are called by the top-level PrimFunc, and I could not get this to work. I know I could simply put everything into one PrimFunc, but I want to know whether cross-function calls are possible.

The following minimal example demonstrates what I am trying to do:

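```python
import tvm
from tvm import te, tir
from tvm.ir import GlobalVar

# Two scalar int32 parameters, shared by both functions
A = te.var('A')
B = te.var('B')

# Callee: evaluates A + B
callee = tir.PrimFunc([A, B], tir.Evaluate(tir.Add(A, B)))
callee = callee.with_attr('global_symbol', 'callee')

# Caller: tries to call the callee PrimFunc directly
main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call('int32', callee, [A, B])))
# OR main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call('int32', GlobalVar('callee'), [A, B])))
main = main.with_attr("global_symbol", "main")
main = main.with_attr("tir.noalias", True)

# Put both functions into one IRModule and build it
mod = tvm.IRModule({"main": main, "callee": callee})
tvm.build(mod)
```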

Thanks @UniverseFly !

Yes, this is something we would love to have, and the IR spec already supports it (by passing in a GlobalVar). But right now we do not yet have compiler support for it.
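To illustrate what "passing in a GlobalVar" means, here is a toy, pure-Python model (not the TVM API): the call site stores only a symbol, a hypothetical `GlobalVar` stand-in holding a name, and the module resolves that symbol to a function when the call is evaluated. The real `tvm.ir.GlobalVar` and TVM's lowering pipeline are not modeled; all classes below are illustrative assumptions.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass(frozen=True)
class GlobalVar:
    """Toy stand-in for a module-level symbol: just a name, not a function."""
    name_hint: str


@dataclass
class Call:
    """Toy call node: the callee is referenced by symbol, not embedded."""
    op: GlobalVar
    args: List[int]


@dataclass
class Module:
    """Toy module: maps symbol names to function bodies."""
    functions: Dict[str, Callable[..., int]] = field(default_factory=dict)

    def lookup(self, gv: GlobalVar) -> Callable[..., int]:
        # Resolution happens at the module level; raises KeyError
        # if no function with this name was added to the module.
        return self.functions[gv.name_hint]

    def eval_call(self, call: Call) -> int:
        return self.lookup(call.op)(*call.args)


mod = Module({"callee": lambda a, b: a + b})
call = Call(GlobalVar("callee"), [2, 3])
print(mod.eval_call(call))  # prints 5
```

Referencing the callee by name is what lets both functions live side by side in one module; the part that is missing, as noted above, is a compiler that lowers such calls.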


Yes, TIR does support it, but you have to change your code a little:

```python
import tvm
from tvm import te, tir
from tvm.ir import GlobalVar

# Two scalar int32 parameters, shared by both functions
A = te.var('A')
B = te.var('B')

# Callee: evaluates A + B
callee = tir.PrimFunc([A, B], tir.Evaluate(tir.Add(A, B)))
callee = callee.with_attr('global_symbol', 'callee')

main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call('int32', callee, [A, B])))
# OR main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call('int32', GlobalVar('callee'), [A, B])))
main = main.with_attr("global_symbol", "main")
# The change: mark main as the entry function of the module
main = main.with_attr("tir.is_entry_func", True)
main = main.with_attr("tir.noalias", True)

mod = tvm.IRModule({"main": main, "callee": callee})
tvm.build(mod)
```

After that, you need to modify the split_host_device.cc pass, which drops the is_entry_func attribute. Then add your own codegen that distinguishes the global (entry) function from the other functions.
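The bookkeeping described here can be sketched with a toy, pure-Python model (not TVM code): each function is represented only by a hypothetical attribute dict, `drop_attr` stands in for a pass that loses an attribute (as the reply says split_host_device.cc does), and `split_entry` stands in for a codegen step that must find exactly one entry function. Both helper names are illustrative assumptions.

```python
from typing import Dict, List, Tuple

# Toy representation: function name -> attribute dict (hypothetical, not TVM's).
Funcs = Dict[str, Dict[str, object]]


def drop_attr(funcs: Funcs, key: str) -> Funcs:
    """Model a pass that rebuilds every function but forgets one attribute."""
    return {name: {k: v for k, v in attrs.items() if k != key}
            for name, attrs in funcs.items()}


def split_entry(funcs: Funcs) -> Tuple[str, List[str]]:
    """Model a codegen step: return the entry function and the other functions."""
    entries = [name for name, attrs in funcs.items()
               if attrs.get("tir.is_entry_func")]
    if len(entries) != 1:
        raise ValueError(f"expected exactly one entry function, found {entries}")
    others = sorted(name for name in funcs if name != entries[0])
    return entries[0], others


funcs: Funcs = {
    "main": {"global_symbol": "main", "tir.is_entry_func": True},
    "callee": {"global_symbol": "callee"},
}
print(split_entry(funcs))  # ('main', ['callee'])

try:
    split_entry(drop_attr(funcs, "tir.is_entry_func"))
except ValueError as err:
    print(err)  # expected exactly one entry function, found []
```

In this model the attribute has to survive every pass between where it is set and where codegen reads it; once it is dropped, the entry function can no longer be told apart from the callee.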
