Hi guys, I'm new to TVM. I'm using mlc-llm to compile an LLM that contains a 2D adaptive average pooling operation. In my model, the input shape of this op is NOT an integer multiple of the output shape (the spatial size goes from 16x40 to 12x30). I tried to use relax.op.nn.adaptive_avg_pool2d, but I got the error below:
ValueError: Traceback (most recent call last):
64: _ZN3tvm7runtime13PackedFuncObj
63: tvm::runtime::TypedPackedFunc<tvm::tir::Schedule (tvm::IRModule, long, int, int, bool)>::AssignTypedLambda<tvm::tir::__mk_TVM16::{lambda(tvm::IRModule, long, int, int, bool)#1}>(tvm::tir::__mk_TVM16::{lambda(tvm::IRModule, long, int, int, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
62: tvm::tir::Schedule::Traced(tvm::IRModule, long, int, tvm::tir::ScheduleErrorRenderLevel, bool)
61: tvm::tir::ScheduleState::ScheduleState(tvm::IRModule, int, bool)
60: tvm::tir::VerifyWellFormed(tvm::tir::PrimFunc const&, bool)
59: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::PrimFunc const&, tvm::ObjectPath)
58: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
57: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
56: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
55: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
54: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
53: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
52: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
51: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
50: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
49: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
48: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
47: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
46: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
45: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
44: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
43: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
42: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
41: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
40: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
39: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
38: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
37: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
36: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
35: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
34: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
33: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
32: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
31: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
30: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
29: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
28: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
27: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
26: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
25: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#8}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
24: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
23: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
22: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
21: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#27}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
20: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::SelectNode const*, tvm::ObjectPath)
19: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
18: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
17: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#16}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
16: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
15: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
14: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
13: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
12: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
11: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
10: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
9: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
8: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
7: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
6: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
5: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#9}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
4: tvm::tir::TIRVisitorWithPath::VisitExpr_(tvm::tir::EQNode const*, tvm::ObjectPath)
3: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
2: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath) const
1: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#1}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
0: tvm::tir::UndefinedVarVerifier::VisitExpr_(tvm::tir::VarNode const*, tvm::ObjectPath)
File "/lpai/volumes/jointmodel/jpf/specdecoding/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_well_formed.cc", line 100
ValueError: Invalid use of undefined variable ax2 at <root>.body.block.body.seq[0].body.body.body.body.extent.a.condition.a.a.a.a.
I printed out the generated TIR; it looks like this (the same function is generated twice, once as adaptive_avg_pool2d reading reshape72 and once as adaptive_avg_pool2d1 reading reshape58):

@T.prim_func(private=True)
def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
    for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
        ax2_1 = T.int64()
        ax3_1 = T.int64()
        with T.block("adaptive_pool_sum"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
            T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            with T.init():
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
            adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))
# NOTE(review): TIR printed for the failing case. The spatial size 16x40 is
# pooled to 12x30, i.e. the input is NOT an integer multiple of the output.
@T.prim_func(private=True)
def adaptive_avg_pool2d1(reshape58: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Intermediate per-window sums; same shape as the output adaptive_pool_avg.
    adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
    # The last two T.grid extents (rv0, rv1) are per-output-element window sizes:
    #   rows:    ceil((i + 1) * 16 / 12) - floor(i * 16 / 12)
    #   columns: ceil((j + 1) * 40 / 30) - floor(j * 40 / 30)
    # The ceil is spelled as a T.Select on whether the division is exact
    # (16 = 4 (mod 12) and 40 = 10 (mod 30), hence the *4 / *10 in the test).
    # The same formulas are applied to v_ax2 / v_ax3 in the divisor of the
    # second block, so i / j are presumably the output row/column index.
    # Here they appear as ax2_1 / ax3_1, which are not loop vars of this T.grid
    # and are only declared (as free T.int64() vars) below the loop header.
    # That matches the VerifyWellFormed error "Invalid use of undefined variable
    # ax2 ... extent". NOTE(review): presumably the printer renamed the undefined
    # var to ax2_1 to avoid clashing with the loop var ax2 -- TODO confirm.
    for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
        ax2_1 = T.int64()
        ax3_1 = T.int64()
        with T.block("adaptive_pool_sum"):
            # ax0..ax3 are bound as spatial ("S") axes, rv0/rv1 as reduction ("R").
            v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
            # Input element read = window start floor(i*16/12), floor(j*40/30)
            # plus the in-window offset (v_rv0, v_rv1).
            T.reads(reshape58[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            with T.init():
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
            adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape58[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            # Average = window sum / (window height * window width). The window
            # sizes are recomputed here from the block vars and cast to float16.
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))
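I noticed that the variables ax2_1 and ax3_1 are used in the extents of the T.grid loop before they are defined (the `ax2_1 = T.int64()` / `ax3_1 = T.int64()` lines come after the loop header). I'm not sure whether this is what causes the error, but it looks wrong to me, especially since the error message complains about an undefined variable ax2 used in a loop extent.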
For comparison, when the input shape is an integer multiple of the output shape (16x40 pooled to 8x20), compilation succeeds, and the TIR looks like this:
# NOTE(review): TIR printed for the working case. The spatial size 16x40 is
# pooled to 8x20, i.e. the input IS an integer multiple of the output.
@T.prim_func(private=True)
def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(8), T.int64(20)), "float16")):
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Intermediate per-window sums; same shape as the output adaptive_pool_avg.
    adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(8), T.int64(20)), "float16")
    # Every window is 2x2 (16/8 rows, 40/20 columns), so the rv0/rv1 extents
    # are the constants 2 and 2. No extent refers to any other variable, unlike
    # the non-divisible case where the extents depend on the output index.
    for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(2), T.int64(1024), T.int64(8), T.int64(20), T.int64(2), T.int64(2)):
        with T.block("adaptive_pool_sum"):
            # ax0..ax3 are bound as spatial ("S") axes, rv0/rv1 as reduction ("R").
            v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
            # Input element read = window start (index * 2) plus in-window offset.
            T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_rv0, v_ax3 * T.int64(2) + v_rv1])
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            with T.init():
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
            adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_rv0, v_ax3 * T.int64(2) + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(8), T.int64(20)):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            # Average = window sum / constant window area (2 * 2).
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.float16(2) * T.float16(2))
Any advice would be appreciated. Thanks!