Question on TensorIR's support of multi-axis parallelization

Hi community, since TensorIR is about to come out, I have been excited to experiment with its scheduling features. However, when I tried to parallelize a simple matrix-scalar multiplication along two axes, the compilation crashed. Here is the complete code snippet:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MatMul:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (4, 4))
        B = T.match_buffer(b, (4, 4))
        for i, j in T.grid(4, 4):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0


sch = tvm.tir.Schedule(MatMul)
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
sch.parallel(i)
sch.parallel(j)

tvm.build(sch.mod)
# libc++abi: terminating with uncaught exception of type tvm::runtime::InternalError
# [1]    19814 abort      python src/main.py

Am I violating some precondition, or is this simply a problem in TVM itself?

On my side, TensorIR prints the correct error message:

Check failed: (!parallel_env_.in_parallel_loop) is false: Nested parallel loop is not supported by threadpool, try fuse them instead

Correct, nested parallelized loops are not allowed anywhere in either TE or TensorIR scheduling, because the threadpool doesn’t allow nested thread launching.
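
To make the rule concrete, here is a minimal sketch of this kind of check in plain Python. It only mirrors the idea of the in_parallel_loop flag from the error message above; the Loop class and check_no_nested_parallel function are hypothetical names for illustration and are not part of TVM.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Loop:
    """Hypothetical stand-in for a loop in a loop nest; not a TVM class."""
    var: str
    kind: str  # "serial" or "parallel"
    body: List["Loop"] = field(default_factory=list)


def check_no_nested_parallel(loop: Loop, in_parallel_loop: bool = False) -> None:
    """Raise ValueError if a parallel loop sits inside another parallel loop."""
    if loop.kind == "parallel":
        if in_parallel_loop:
            raise ValueError(
                f"Nested parallel loop at '{loop.var}' is not supported, "
                "try fusing the loops instead"
            )
        # Everything below this loop is now inside a parallel region.
        in_parallel_loop = True
    for child in loop.body:
        check_no_nested_parallel(child, in_parallel_loop)


def is_accepted(loop: Loop) -> bool:
    """Return True if the loop nest passes the check."""
    try:
        check_no_nested_parallel(loop)
        return True
    except ValueError:
        return False


# Parallel j nested inside parallel i: rejected.
nested = Loop("i", "parallel", [Loop("j", "parallel")])
# A single fused parallel loop: accepted.
fused = Loop("i_j_fused", "parallel")
# A parallel loop under a serial loop: accepted.
serial_outer = Loop("i", "serial", [Loop("j", "parallel")])
```

The flag is passed down the recursion rather than stored globally, so sibling loops outside the parallel region are unaffected.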

Thanks for your reply! Got the idea.

You can fuse the two loops and parallelize the fused loop like:

ij = sch.fuse([i, j])
sch.parallel(ij)
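
As a rough picture of what the fused loop does, the sketch below runs the same 4x4 computation in plain Python with one flat parallel loop, recovering i and j from the fused index by division and remainder. It does not use TVM; the thread pool from the standard library and the names N, M and body are assumptions made for illustration only.

```python
from concurrent.futures import ThreadPoolExecutor

N, M = 4, 4
A = [[float(i * M + j) for j in range(M)] for i in range(N)]
B = [[0.0] * M for _ in range(N)]


def body(ij: int) -> None:
    # Recover the original loop indices from the fused index.
    i, j = divmod(ij, M)
    B[i][j] = A[i][j] * 2.0


# One flat parallel loop over N * M iterations, with no nesting.
with ThreadPoolExecutor() as pool:
    list(pool.map(body, range(N * M)))
```

Each iteration writes a distinct element of B, so the iterations are independent and can run in any order.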

For devices like GPU, you can bind i and j to blockIdx.x/y/z and threadIdx.x/y/z.

By the way, the error message looks interesting to me. I would expect this to be caught at the FFI boundary.

Yep, it’s strange! I find the problem still exists with the latest TVM commit. I ran TVM on macOS 12.0.1 and it’s built with LLVM 12. LLDB gives me this message:

* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGABRT
  * frame #0: 0x00007ff811569112 libsystem_kernel.dylib`__pthread_kill + 10
    frame #1: 0x00007ff81159f233 libsystem_pthread.dylib`pthread_kill + 263
    frame #2: 0x00007ff8114ebd10 libsystem_c.dylib`abort + 123
    frame #3: 0x00007ff81155c0b2 libc++abi.dylib`abort_message + 241
    frame #4: 0x00007ff81154d1fd libc++abi.dylib`demangling_terminate_handler() + 266
    frame #5: 0x00007ff81144a511 libobjc.A.dylib`_objc_terminate() + 104
    frame #6: 0x00007ff81155b4d7 libc++abi.dylib`std::__terminate(void (*)()) + 8
    frame #7: 0x00007ff81155b488 libc++abi.dylib`std::terminate() + 56
    frame #8: 0x000000015c8352ab libtvm.dylib`__clang_call_terminate + 11
    frame #9: 0x000000015dcd0775 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 149
    frame #10: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #11: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #12: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #13: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #14: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #15: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #16: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #17: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #18: 0x000000015dcd0d85 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*) + 181
    frame #19: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #20: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #21: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #22: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #23: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #24: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #25: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #26: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #27: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #28: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #29: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #30: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #31: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #32: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #33: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #34: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #35: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #36: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #37: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #38: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #39: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #40: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #41: 0x000000015dcd0d85 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*) + 181
    frame #42: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #43: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #44: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #45: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #46: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #47: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #48: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #49: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #50: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #51: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #52: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #53: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #54: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #55: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #56: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #57: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #58: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #59: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #60: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #61: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #62: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #63: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #64: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #65: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #66: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #67: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #68: 0x000000015dccfc9b libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*) + 571
    frame #69: 0x000000015dcadbbb libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*) + 859
    frame #70: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #71: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #72: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #73: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #74: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #75: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #76: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #77: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #78: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #79: 0x000000015dccfc9b libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*) + 571
    frame #80: 0x000000015dcadbbb libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*) + 859
    frame #81: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #82: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #83: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #84: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #85: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #86: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #87: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #88: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #89: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #90: 0x000000015dcd09bc libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*) + 428
    frame #91: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #92: 0x000000015dcd0730 libtvm.dylib`tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
    frame #93: 0x000000015dcad6ec libtvm.dylib`tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 748
    frame #94: 0x000000015c886f8f libtvm.dylib`tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 303
    frame #95: 0x000000015dcc0988 libtvm.dylib`tvm::codegen::CodeGenLLVM::AddFunctionInternal(tvm::tir::PrimFunc const&, bool) + 3176
    frame #96: 0x000000015dca475a libtvm.dylib`tvm::codegen::CodeGenCPU::AddFunction(tvm::tir::PrimFunc const&) + 42
    frame #97: 0x000000015dcbbeff libtvm.dylib`void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<std::__1::__wrap_iter<tvm::tir::PrimFunc*>, void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<std::__1::__wrap_iter<tvm::tir::PrimFunc*> >(std::__1::__wrap_iter<tvm::tir::PrimFunc*>, std::__1::__wrap_iter<tvm::tir::PrimFunc*>)::'lambda'(std::__1::__wrap_iter<tvm::tir::PrimFunc*>)>(std::__1::__wrap_iter<tvm::tir::PrimFunc*>, std::__1::__wrap_iter<tvm::tir::PrimFunc*>, void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<std::__1::__wrap_iter<tvm::tir::PrimFunc*> >(std::__1::__wrap_iter<tvm::tir::PrimFunc*>, std::__1::__wrap_iter<tvm::tir::PrimFunc*>)::'lambda'(std::__1::__wrap_iter<tvm::tir::PrimFunc*>)) + 495
    frame #98: 0x000000015dcff439 libtvm.dylib`tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, tvm::Target const&) + 7177
    frame #99: 0x000000015dcfd5d9 libtvm.dylib`std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::codegen::$_0>(tvm::codegen::$_0, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::codegen::$_0>(tvm::codegen::$_0, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 825
    frame #100: 0x000000015d2e27b4 libtvm.dylib`tvm::codegen::Build(tvm::IRModule, tvm::Target) + 1700
    frame #101: 0x000000015caba6db libtvm.dylib`tvm::PreProcessModuleForBuild(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&) + 2619
    frame #102: 0x000000015cac8a42 libtvm.dylib`std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::$_6>(tvm::$_6, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::$_6>(tvm::$_6, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 770
    frame #103: 0x000000015dd26706 libtvm.dylib`TVMFuncCall + 70
    frame #104: 0x00000001006d8ead libffi.7.dylib`ffi_call_unix64 + 85
    frame #105: 0x00000001006d8579 libffi.7.dylib`ffi_call_int + 761
    frame #106: 0x0000000101b314d8 _ctypes.cpython-38-darwin.so`_ctypes_callproc + 744
    frame #107: 0x0000000101b2b523 _ctypes.cpython-38-darwin.so`PyCFuncPtr_call + 275
    frame #108: 0x000000010002e21d python`_PyObject_MakeTpCall + 173
    frame #109: 0x0000000100179ecb python`call_function + 315
    frame #110: 0x00000001001763e2 python`_PyEval_EvalFrameDefault + 43522
    frame #111: 0x0000000100169dc4 python`_PyEval_EvalCodeWithName + 564
    frame #112: 0x000000010002f8ca python`_PyFunction_Vectorcall + 426
    frame #113: 0x000000010002df5d python`_PyObject_FastCallDict + 93
    frame #114: 0x00000001000b399e python`slot_tp_call + 174
    frame #115: 0x000000010002e21d python`_PyObject_MakeTpCall + 173
    frame #116: 0x0000000100179ecb python`call_function + 315
    frame #117: 0x00000001001763e2 python`_PyEval_EvalFrameDefault + 43522
    frame #118: 0x0000000100169dc4 python`_PyEval_EvalCodeWithName + 564
    frame #119: 0x000000010002f8ca python`_PyFunction_Vectorcall + 426
    frame #120: 0x0000000100179e32 python`call_function + 162
    frame #121: 0x00000001001763e2 python`_PyEval_EvalFrameDefault + 43522
    frame #122: 0x0000000100169dc4 python`_PyEval_EvalCodeWithName + 564
    frame #123: 0x00000001001f128f python`pyrun_file + 271
    frame #124: 0x00000001001f0a8f python`pyrun_simple_file + 463
    frame #125: 0x00000001001f086e python`PyRun_SimpleFileExFlags + 110
    frame #126: 0x000000010021a35f python`pymain_run_file + 463
    frame #127: 0x0000000100219876 python`pymain_run_python + 534
    frame #128: 0x0000000100219605 python`Py_RunMain + 37
    frame #129: 0x000000010021ae21 python`pymain_main + 49
    frame #130: 0x0000000100001678 python`main + 56
    frame #131: 0x00000001004704fe dyld`start + 462

Weird, I’m unable to reproduce this. The tvm::runtime::InternalError is defined here: https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h#L227

Could you check whether no InternalError is caught at all, or whether this is something specific to our codegen?

I am able to reproduce this. Codegen correctly raised an error and it was caught at the FFI boundary, but some error happened inside malloc during the post-processing of the error message.

I am not sure what was wrong here. My env:

ldd libtvm_runtime.so
	linux-vdso.so.1 (0x00007ffe87169000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f11d84f0000)
	libcudart.so.11.0 => /usr/local/cuda/lib64/libcudart.so.11.0 (0x00007f11d824e000)
	libcuda.so.1 => /lib/x86_64-linux-gnu/libcuda.so.1 (0x00007f11d6a84000)
	libstdc++.so.6 => /lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f11d686b000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f11d671d000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f11d6700000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f11d66de000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f11d64f2000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f11d8b09000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f11d64e7000)

Memory is corrupted if an exception is raised in the LLVM codegen. This line changes the destruction order.

Here analyzer_ becomes a local variable. If an exception is thrown during the recursive visit, it is destructed earlier than expected (some other objects have destructors that depend on analyzer_).

cc @kparzysz

Do you know what objects specifically?

It is ConstraintContext (https://github.com/apache/tvm/blob/main/src/target/llvm/codegen_llvm.cc#L1496), which is created before the new analyzer is created.

I guess the simplest thing to do would be changing

VisitStmt(op->body);

to

try {
  VisitStmt(op->body);
} catch (...) {
  std::swap(analyzer_, new_analyzer);
  throw;
}

It’s not super elegant, but it should fix the problem. There are several places where this swap takes place, and I think all of them will need to be updated.
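
The swap-and-restore pattern can be sketched in Python as follows. This is only a model: Python has no C++ destructors, so it shows which analyzer the member refers to after an exception propagates, not the memory corruption itself. The Analyzer and Codegen classes and the visit_unsafe and visit_safe methods are hypothetical names and not TVM code.

```python
class Analyzer:
    """Hypothetical stand-in for the analyzer object; not a TVM class."""

    def __init__(self, name: str) -> None:
        self.name = name


class Codegen:
    def __init__(self) -> None:
        self.analyzer = Analyzer("outer")

    def visit_unsafe(self, body) -> None:
        new_analyzer = Analyzer("inner")
        self.analyzer, new_analyzer = new_analyzer, self.analyzer
        body()  # if this raises, the swap back below never runs
        self.analyzer, new_analyzer = new_analyzer, self.analyzer

    def visit_safe(self, body) -> None:
        new_analyzer = Analyzer("inner")
        self.analyzer, new_analyzer = new_analyzer, self.analyzer
        try:
            body()
        except BaseException:
            # Swap back before the exception leaves this frame.
            self.analyzer, new_analyzer = new_analyzer, self.analyzer
            raise
        self.analyzer, new_analyzer = new_analyzer, self.analyzer


def failing_body() -> None:
    raise RuntimeError("Nested parallel loop is not supported")


def analyzer_after_failure(visit_name: str) -> str:
    """Run the named visit method with a failing body; report the member analyzer."""
    codegen = Codegen()
    try:
        getattr(codegen, visit_name)(failing_body)
    except RuntimeError:
        pass
    return codegen.analyzer.name
```

With visit_unsafe the member still refers to the inner analyzer after the failure, which in the C++ case means the original one was owned by the local and destroyed during unwinding. With visit_safe the member refers to the original analyzer again.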

That makes a lot of sense. Thanks @vinx13 @kparzysz for digging into this!

Hi, sorry for adding to an old post. I recently encountered the same error while playing with TIR scheduling primitives. The discussion in the thread makes sense to me and I understand the fundamental difficulty of having nested parallelization.

One thing that confuses me: according to the TIR paper (ASPLOS 2023), TIR guarantees the correctness of schedule primitives, meaning that if a schedule can be applied without error, it is valid. In this case, nested parallelism is not allowed, yet the error is not raised until codegen rather than at TIR scheduling time. Adding a check here seems more consistent with the TIR correctness claims. I would be happy to send a PR for this. cc @junrushao @yzh119