Hi all, how do I call a `prim_func` from another `prim_func` in TensorIR?
```python
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        # We exchange data between functions via handles, which are similar to pointers.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Create buffers from the handles.
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        for i in range(8):
            # A block is an abstraction for computation.
            with T.block("B"):
                # Define a spatial block iterator and bind it to value i.
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + 1.0
        # Call the second PrimFunc by its global symbol.
        T.evaluate(T.call_extern("plus_one", B.data, dtype="float32"))

    @T.prim_func
    def plus_one(a: T.handle):
        T.func_attr({"global_symbol": "plus_one", "tir.noalias": True})
        A = T.match_buffer(a, (8, 8), dtype="float32")
        for i, j in T.grid(8, 8):
            with T.block("plus"):
                vi, vj = T.axis.remap("SS", [i, j])
                A[vi, vj] = A[vi, vj] + 1.0
```
I tried to build the module above with the `llvm` target, but got the following error:
```
at /project/ai/scratch01/maxxu01/workspaces/AIPU_tvm_design/src/target/llvm/codegen_llvm.cc:137
File "/project/ai/scratch01/maxxu01/workspaces/AIPU_tvm_design/src/target/llvm/codegen_llvm.cc", line 137
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr) is false: Function "plus_one" already exist in module
```
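One plausible reading of that check (an assumption, not confirmed by the log): the LLVM codegen keys every function in a single module by its global symbol, so if `main` is compiled first, its `call_extern("plus_one", ...)` leaves behind a declaration named `plus_one`, and adding the real `plus_one` PrimFunc afterwards trips the `getFunction(...) == nullptr` check. The plain-Python toy model below sketches that sequence; `ToyLLVMModule`, `declare_extern`, `add_function`, and `build` are hypothetical names, not TVM APIs.

```python
# Toy model (hypothetical, NOT TVM's API) of a codegen module that keys
# functions by global symbol, mimicking the check in the error above.


class ToyLLVMModule:
    def __init__(self):
        # global symbol -> "declaration" or "definition"
        self.functions = {}

    def get_function(self, name):
        return self.functions.get(name)

    def declare_extern(self, name):
        # An extern call only needs a declaration; reuse one if it exists.
        if self.get_function(name) is None:
            self.functions[name] = "declaration"

    def add_function(self, name):
        # Mirrors: Check failed: module_->getFunction(...) == nullptr
        if self.get_function(name) is not None:
            raise RuntimeError(f'Function "{name}" already exist in module')
        self.functions[name] = "definition"


def build(funcs):
    """Compile (name, extern_callees) pairs, in order, into one toy module."""
    mod = ToyLLVMModule()
    for name, extern_callees in funcs:
        mod.add_function(name)
        # Compiling the body of `name` declares every extern symbol it calls.
        for callee in extern_callees:
            mod.declare_extern(callee)
    return mod


try:
    build([("main", ["plus_one"]), ("plus_one", [])])
except RuntimeError as err:
    print(err)  # Function "plus_one" already exist in module
```

In this toy model, names are only checked when `build` runs, so a module that is constructed but never compiled never reaches the collision.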
However, many unit tests, such as `test_tir_usmp_utils.py`, call another function in the same way, although they never build the module.
Any ideas? THX.