TensorIR: calling a prim_func from another prim_func

Hi all, how do I call a prim_func from another prim_func in TensorIR?

```python
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        # We exchange data between functions by handles, which are similar to pointers.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Create buffers from handles.
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        for i in range(8):
            # A block is an abstraction for computation.
            with T.block("B"):
                # Define a spatial block iterator and bind it to value i.
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + 1.0
        # Attempt to invoke plus_one on B's data pointer.
        T.evaluate(T.call_extern("plus_one", B.data, dtype="float32"))

    @T.prim_func
    def plus_one(a: T.handle):
        T.func_attr({"global_symbol": "plus_one", "tir.noalias": True})
        # B in main has 8 elements, so match an (8,) buffer here.
        A = T.match_buffer(a, (8,), dtype="float32")
        for i in range(8):
            with T.block("plus"):
                vi = T.axis.spatial(8, i)
                A[vi] = A[vi] + 1.0
```

I tried to build the above module with the “llvm” target, but got the following error:

```
at /project/ai/scratch01/maxxu01/workspaces/AIPU_tvm_design/src/target/llvm/codegen_llvm.cc:137
  File "/project/ai/scratch01/maxxu01/workspaces/AIPU_tvm_design/src/target/llvm/codegen_llvm.cc", line 137
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr) is false: Function "plus_one" already exist in module
```

However, many unit tests, such as “test_tir_usmp_utils.py”, call another function this way, although they never build the module.

Any ideas? Thanks.

Refer to the thread “Calling a PrimFunc from TVM IR”.

Thanks for your reply.

But can’t it work like TE? Is Relay necessary in this case?

Hi, to summarize: TIR → TIR calls are not fully supported yet. In your example, T.call_extern("plus_one", B.data, dtype="float32") does not, in general, mean invoking the TIR function of that name defined in the same module. Ideally it would be written simply as plus_one(B.data, dtype="float32"), but that is not supported.
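The error message hints at why the build fails. One plausible reading (an assumption, not a description of TVM's actual codegen): lowering `T.call_extern("plus_one", ...)` declares an external symbol `plus_one` in the LLVM module, so when the PrimFunc with the same global symbol is added afterwards, the duplicate-name check fires. The toy model below replays that sequence; the `ToyModule` class and its methods are hypothetical, not TVM or LLVM APIs.

```python
from typing import Dict


class ToyModule:
    """Hypothetical stand-in for an LLVM module's symbol table (not TVM code)."""

    def __init__(self) -> None:
        # symbol name -> "declaration" (extern, no body) or "definition"
        self.symbols: Dict[str, str] = {}

    def declare_extern(self, name: str) -> None:
        # An extern call only needs a declaration; reuse the symbol if it exists.
        self.symbols.setdefault(name, "declaration")

    def add_function(self, name: str) -> None:
        # Mirrors the quoted check: the symbol must not already be in the module.
        if name in self.symbols:
            raise RuntimeError(f'Function "{name}" already exist in module')
        self.symbols[name] = "definition"


mod = ToyModule()
mod.add_function("main")
mod.declare_extern("plus_one")  # main's body contains call_extern("plus_one", ...)
try:
    mod.add_function("plus_one")  # then the PrimFunc named "plus_one" is added
except RuntimeError as err:
    print(err)  # prints: Function "plus_one" already exist in module
```

In this toy model an extern declaration and a real definition share one namespace, which is why a same-named extern call and PrimFunc cannot coexist once the declaration is registered first.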


Here is my solution: first save the global vars and their corresponding prim_funcs into a map, then inline the callee's body wherever one prim_func calls another. Keeping real function calls would make buffer analysis quite complicated.
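As a rough illustration of this map-then-inline approach, here is a toy version in plain Python. The IR (`AddConst`, `Call`, `Func`) and the helpers `inline_calls` and `run` are hypothetical stand-ins, not TVM classes; a real TIR pass would also have to remap buffer shapes, strides and block names, and rename callee-local variables to avoid clashes.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


# Hypothetical toy IR (not TVM's TIR): one statement kind that adds a constant
# to every element of a named buffer, and one that calls another function.
@dataclass
class AddConst:
    buf: str
    value: float


@dataclass
class Call:
    callee: str
    args: List[str]


Stmt = Union[AddConst, Call]


@dataclass
class Func:
    params: List[str]
    body: List[Stmt]


def inline_calls(func: Func, table: Dict[str, Func]) -> Func:
    """Return a copy of `func` with every Call replaced by the callee's body.

    Assumes the call graph has no recursion (otherwise this never terminates).
    """
    flat: List[Stmt] = []
    for stmt in func.body:
        if isinstance(stmt, Call):
            # Look up the callee in the global map and flatten its own calls first.
            callee = inline_calls(table[stmt.callee], table)
            if len(callee.params) != len(stmt.args):
                raise ValueError(f"arity mismatch when calling {stmt.callee}")
            # Bind the callee's formal parameters to the caller's buffers.
            rename = dict(zip(callee.params, stmt.args))
            for inner in callee.body:
                # After flattening, the callee body holds only AddConst statements.
                assert isinstance(inner, AddConst)
                flat.append(AddConst(rename.get(inner.buf, inner.buf), inner.value))
        else:
            flat.append(stmt)
    return Func(func.params, flat)


def run(func: Func, env: Dict[str, List[float]]) -> None:
    """Interpret a call-free function, updating the buffers in `env`."""
    for stmt in func.body:
        if not isinstance(stmt, AddConst):
            raise TypeError("run() expects a function whose calls were inlined")
        env[stmt.buf] = [x + stmt.value for x in env[stmt.buf]]


# Usage mirroring the thread: main adds 1.0 to B, then "calls" plus_one on B.
plus_one = Func(params=["A"], body=[AddConst("A", 1.0)])
main = Func(params=["B"], body=[AddConst("B", 1.0), Call("plus_one", ["B"])])
func_table = {"main": main, "plus_one": plus_one}  # global name -> function

flat_main = inline_calls(main, func_table)
env = {"B": [0.0] * 8}
run(flat_main, env)
print(env["B"])  # every element is now 2.0
```

The trade-off is code size: each call site gets its own copy of the callee, but the result is a single flat body whose buffer accesses an analysis can see in full.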

I agree that this is very complicated. I am trying to wrap these prim_funcs with relay.Function, but I haven’t finished yet. I am not familiar with the Relay DSL, especially function calls (e.g., the arguments and the return value). Again, it is quite complicated for me.

What I want is to implement these long-tail operators with tir.script rather than the IRBuilder, because with the IRBuilder it is hard to tell Python AST nodes apart from TIR nodes, which bothers me a lot when debugging. For better performance, I decided to inline these functions together. You could use Relay as well, but fusing these operators would still be hard.

Inlining is fine when the prim_func is simple. In my case, both prim_funcs are complicated.

TIR doesn’t have a construct for function calls, so this is a hard problem. Can you tell me your use case (why you need to call a complicated function from another)?

Would you guys like to try T.call_packed instead of T.call_extern?

In our case, the callee is a common subroutine that we want to build as a separate function. It will be used to precompute some indices for other functions.

I am wondering whether, in this case, it would be sufficient to provide a T.call intrinsic in tir/intrinsic.py that inserts a tir.Call node.

It seems that there is no T.call_packed; call_packed only exists in the original TIR IRBuilder infrastructure.

It seems that cross-calls between prim_funcs are not supported yet; I hope they will be supported in the future.

Can you show your code?