Does TIR support cross-function call?

Yes, TIR does support it, but you need to adjust your code a little:

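```python
import tvm
from tvm import te, tir
from tvm.ir import GlobalVar  # used by the alternative call form below

# Two scalar (int32) parameters shared by both functions.
A = te.var("A")
B = te.var("B")

# Callee: evaluates A + B and is emitted under the symbol name "callee".
callee = tir.PrimFunc([A, B], tir.Evaluate(tir.Add(A, B)))
callee = callee.with_attr("global_symbol", "callee")

# Caller: passes the callee PrimFunc itself as the call target...
main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call("int32", callee, [A, B])))
# ...OR refers to the callee by name through a GlobalVar:
# main = tir.PrimFunc([A, B], tir.Evaluate(tir.Call("int32", GlobalVar("callee"), [A, B])))
main = main.with_attr("global_symbol", "main")
main = main.with_attr("tir.is_entry_func", True)  # mark main as the entry function
main = main.with_attr("tir.noalias", True)

# Put both functions in one IRModule so they are built together.
mod = tvm.IRModule({"main": main, "callee": callee})
tvm.build(mod)
```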

After that, you need to modify the `split_host_device.cc` pass, which drops the `tir.is_entry_func` attribute. Then add your own codegen that checks whether each function is the global entry function or one of the other (callee) functions.
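The check a custom codegen has to make can be sketched in plain Python. This is only a toy model, not TVM code: a real implementation would live in C++ and read the `tir.is_entry_func` attribute from each `PrimFunc`. The dict-based module and the `classify_functions` helper below are hypothetical, chosen to mirror the attributes set in the example above.

```python
# Toy model (hypothetical, not TVM's API): each function in the module is a
# name mapped to a plain dict of attributes, mirroring the "global_symbol"
# and "tir.is_entry_func" attributes set on the PrimFuncs above.
from typing import Dict, List, Tuple

Attrs = Dict[str, object]


def classify_functions(funcs: Dict[str, Attrs]) -> Tuple[str, List[str]]:
    """Return (entry function name, sorted names of the other functions).

    Raises ValueError unless exactly one function is marked as the entry.
    This is the failure a codegen would hit if a pass such as
    split_host_device dropped the "tir.is_entry_func" attribute.
    """
    entries = [name for name, attrs in funcs.items()
               if attrs.get("tir.is_entry_func", False)]
    if len(entries) != 1:
        raise ValueError(f"expected exactly one entry function, found {entries}")
    others = sorted(name for name in funcs if name != entries[0])
    return entries[0], others


module = {
    "main": {"global_symbol": "main", "tir.is_entry_func": True, "tir.noalias": True},
    "callee": {"global_symbol": "callee"},
}
entry, helpers = classify_functions(module)
print(entry, helpers)  # main ['callee']
```

Failing loudly when no entry function is found makes a lost attribute show up at codegen time instead of silently emitting every function as a global symbol.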
