Compilation error for adaptive_avg_pool2d relax op in mlc-llm

Hi guys, I am new to TVM. I am using mlc-llm to compile an LLM that includes an adaptive average pooling 2d operation. In my model, the input shape of this op is NOT an integer multiple of the output shape. I tried to use relax.op.nn.adaptive_avg_pool2d, but I got the error below:

ValueError: Traceback (most recent call last):
  64: _ZN3tvm7runtime13PackedFuncObj
  63: tvm::runtime::TypedPackedFunc<tvm::tir::Schedule (tvm::IRModule, long, int, int, bool)>::AssignTypedLambda<tvm::tir::__mk_TVM16::{lambda(tvm::IRModule, long, int, int, bool)#1}>(tvm::tir::__mk_TVM16::{lambda(tvm::IRModule, long, int, int, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  62: tvm::tir::Schedule::Traced(tvm::IRModule, long, int, tvm::tir::ScheduleErrorRenderLevel, bool)
  61: tvm::tir::ScheduleState::ScheduleState(tvm::IRModule, int, bool)
  60: tvm::tir::VerifyWellFormed(tvm::tir::PrimFunc const&, bool)
  59: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::PrimFunc const&, tvm::ObjectPath)
  58: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  57: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  56: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  55: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
  54: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  53: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  52: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  51: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
  50: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  49: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  48: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  47: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
  46: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  45: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  44: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  43: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  42: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  41: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  40: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  39: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  38: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  37: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  36: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  35: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  34: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  33: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  32: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  31: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  30: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  29: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  28: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  27: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  26: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  25: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#8}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  24: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
  23: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  22: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  21: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#27}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  20: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::SelectNode const*, tvm::ObjectPath)
  19: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  18: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  17: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#16}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  16: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
  15: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  14: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  13: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  12: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
  11: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  10: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  9: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  8: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
  7: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  6: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  5: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#9}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  4: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
  3: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  2: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
  1: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#1}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  0: tvm::tir::UndefinedVarVerifier::VisitExpr_(tvm::tir::VarNode const*, tvm::ObjectPath)
  File "/lpai/volumes/jointmodel/jpf/specdecoding/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_well_formed.cc", line 100
ValueError: Invalid use of undefined variable ax2 at <root>.body.block.body.seq[0].body.body.body.body.extent.a.condition.a.a.a.a.
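In case it helps, here is a minimal standalone sketch of how the op gets constructed and legalized (this is not my actual model code; the shapes are copied from the TIR dump below):

import tvm
from tvm import relax

# Hypothetical minimal reproducer; 16 -> 12 and 40 -> 30 are not
# integer-ratio poolings, which is the failing configuration.
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 1024, 16, 40), "float16"))
with bb.function("main", [x]):
    y = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(12, 30)))
    bb.emit_func_output(y)
# LegalizeOps lowers the relax op to the adaptive_avg_pool2d prim func.
mod = relax.transform.LegalizeOps()(bb.get())
mod.show()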

I tried to print out the generated TIR; it looks like:

    @T.prim_func(private=True)
    def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64()
            ax3_1 = T.int64()
            with T.block("adaptive_pool_sum"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                with T.init():
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))

    @T.prim_func(private=True)
    def adaptive_avg_pool2d1(reshape58: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64()
            ax3_1 = T.int64()
            with T.block("adaptive_pool_sum"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                T.reads(reshape58[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                with T.init():
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape58[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))

I found that the variables ax2_1 and ax3_1 are used (in the T.grid loop extents) before they are defined inside the loop body. I am not sure whether this is what causes the error, but it looks wrong to me.
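For background, this is how I understand adaptive pooling sizes each window (a plain-Python sketch of the usual start/end formula; the helper is mine, not TVM code):

def window_extent(i, in_size, out_size):
    # Output index i averages input positions [start, end), where
    # start = floor(i * in_size / out_size) and
    # end   = ceil((i + 1) * in_size / out_size).
    start = i * in_size // out_size
    end = ((i + 1) * in_size + out_size - 1) // out_size
    return end - start

# Divisible case (16 -> 8): a constant extent of 2, as in the working TIR below.
print({window_extent(i, 16, 8) for i in range(8)})    # {2}
# My non-divisible case (16 -> 12): the value happens to be 2 for every i,
# but the expression depends on i, hence the symbolic T.Select in the T.grid.
print({window_extent(i, 16, 12) for i in range(12)})  # {2}
# A case where the extents genuinely vary (16 -> 6):
print({window_extent(i, 16, 6) for i in range(6)})    # {3, 4}

So in the non-divisible case the reduction extent is legitimately an expression of the output index; the suspicious part is that the generated extents reference the fresh variables ax2_1/ax3_1 rather than the loop variables ax2/ax3.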

For comparison, when the input shape of this op is an integer multiple of the output shape, compilation succeeds, and the TIR looks like:

@T.prim_func(private=True)
def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(8), T.int64(20)), "float16")):
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(8), T.int64(20)), "float16")
    for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(8), T.int64(20), T.int64(2), T.int64(2)):
        with T.block("adaptive_pool_sum"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
            T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_rv0, v_ax3 * T.int64(2) + v_rv1])
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            with T.init():
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
            adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_rv0, v_ax3 * T.int64(2) + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(8), T.int64(20)):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.float16(2) * T.float16(2))

Any advice is appreciated!! Thanks.

Hi, thanks for finding this! Could you reproduce the error using the low-level TE interface?

import tvm
from tvm import te, topi

x = te.placeholder([1, 1024, 16, 40], "float32", "x")
y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
f = te.create_prim_func([x, y])
print(f)
tvm.build(f, target="llvm")

If so, it may be an issue similar to Te.create_prim_func vs tvm.lower. A fix has been sent to the repo. Please refer to
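For comparison, the same compute can also go through the legacy schedule-based lowering (a sketch; if tvm.lower builds this fine while te.create_prim_func does not, that narrows the bug to the prim-func conversion):

import tvm
from tvm import te, topi

x = te.placeholder([1, 1024, 16, 40], "float32", "x")
y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
s = te.create_schedule(y.op)  # legacy TE schedule
print(tvm.lower(s, [x, y], simple_mode=True))
tvm.build(s, [x, y], target="llvm")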

Thanks for your quick answer and fix!

When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I saw when compiling the model.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspace/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/workspace/tvm/src/driver/driver_api.cc", line 532, in operator()
    return TIRToRuntime(inputs_arg, host_target);
  File "/workspace/tvm/src/driver/driver_api.cc", line 493, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
    auto pair = SplitMixedModule(ir_module, target, target_host);
  File "/workspace/tvm/src/driver/driver_api.cc", line 419, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
    mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
  File "/workspace/tvm/src/driver/driver_api.cc", line 290, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
    mod = seq(std::move(mod));
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 435, in operator()
    func = MakePackedAPI(std::move(func));
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
    ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
tvm.error.InternalError: Traceback (most recent call last):
  5: operator()
        at /workspace/tvm/src/driver/driver_api.cc:532
  4: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
        at /workspace/tvm/src/driver/driver_api.cc:493
  3: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
        at /workspace/tvm/src/driver/driver_api.cc:419
  2: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
        at /workspace/tvm/src/driver/driver_api.cc:290
  1: operator()
        at /workspace/tvm/src/tir/transforms/make_packed_api.cc:435
  0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
        at /workspace/tvm/src/tir/transforms/make_packed_api.cc:398
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398
InternalError: Check failed: undefined.size() == 0 (2 vs. 0) : In PrimFunc default_function variables [ax2, ax3] are used, but are not passed in as API arguments

I also applied your fix to TVM, and when I re-compiled the model, a new error appeared:

build
    relax.build(
  File "/workspace/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
  File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
    TVMAPISetLastPythonError(((void *)__pyx_v_err));
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/mnt/volumes/jointmodel/songtianchen/mlc-llm-dev-eagle-lpai/python/mlc_llm/compiler_pass/pipeline.py", line 181, in _pipeline
    mod = seq(mod)
  File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
    TVMAPISetLastPythonError(((void *)__pyx_v_err));
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/workspace/tvm/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 71, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 87, in _apply_rules
    space = rule.apply(func, target, tunable)
  File "/workspace/tvm/python/tvm/dlight/gpu/general_reduction.py", line 114, in apply
    sch.compute_at(block, bx, preserve_unit_loops=True)
  File "/workspace/tvm/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/workspace/tvm/python/tvm/tir/schedule/schedule.py", line 2111, in compute_at
    _ffi_api.ScheduleComputeAt(  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
        at /workspace/tvm/src/tir/schedule/traced_schedule.cc:489
  0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
        at /workspace/tvm/src/tir/schedule/concrete_schedule.cc:790
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(var_reshape240: T.handle, var_adaptive_pool_avg: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        reshape240 = T.match_buffer(var_reshape240, (T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16")
        adaptive_pool_avg = T.match_buffer(var_adaptive_pool_avg, (T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        # tir.Block#0
        with T.block("root"):
        ^^^^^^^^^^^^^^^^^^^^^
            T.reads()
            ^^^^^^^^^
            T.writes()
            ^^^^^^^^^^
            adaptive_pool_sum_shared = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16", scope="shared")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(2)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(1024)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax2 in range(T.int64(12)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax3 in range(T.int64(30)):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            for ax4 in range(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12)):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                ax2_1 = T.int64()
                                ^^^^^^^^^^^^^^^^^
                                for ax5 in range(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    ax3_1 = T.int64()
                                    ^^^^^^^^^^^^^^^^^
                                    with T.block("adaptive_pool_sum"):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v0 = T.axis.spatial(T.int64(2), ax0)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v1 = T.axis.spatial(T.int64(1024), ax1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v2 = T.axis.spatial(T.int64(12), ax2)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v3 = T.axis.spatial(T.int64(30), ax3)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v4 = T.axis.reduce(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), ax4)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v5 = T.axis.reduce(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30), ax5)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.writes(adaptive_pool_sum_shared[v0, v1, v2, v3])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        with T.init():
                                        ^^^^^^^^^^^^^^
                                            adaptive_pool_sum_shared[v0, v1, v2, v3] = T.float16(0)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        adaptive_pool_sum_shared[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0_ax1_ax2_ax3_fused in T.thread_binding(T.int64(737280), thread="blockIdx.x"):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax4 in range(T.int64(1)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax5_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax5_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.block("adaptive_pool_avg"):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v0 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_fused // T.int64(368640))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v1 = T.axis.spatial(T.int64(1024), ax0_ax1_ax2_ax3_fused % T.int64(368640) // T.int64(360))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v2 = T.axis.spatial(T.int64(12), ax0_ax1_ax2_ax3_fused % T.int64(360) // T.int64(30))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v3 = T.axis.spatial(T.int64(30), ax0_ax1_ax2_ax3_fused % T.int64(30))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v4 = T.axis.spatial(T.int64(1), ax4)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v5 = T.axis.spatial(T.int64(1), ax5_0 * T.int64(256) + ax5_1)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.where(ax5_0 * T.int64(256) + ax5_1 < T.int64(1))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.reads(adaptive_pool_sum_shared[v0, v1, v2, v3])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v3) * T.int64(40) // T.int64(30)))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and For

I also printed out the generated TIR and compared it with the one produced by the test code you provided; they are almost the same.

    @T.prim_func(private=True)
    def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))

Could you help look further into this issue? Many thanks!