[issue] Failed to meta-schedule a third-party op

version: 0.20.0

I am using TVM's meta-schedule to optimize an op I have written, but it failed. Here are the code and the logs the script produced.

from tvm.script import tir as T
from tvm.script import relax as R
import tvm.script.parser as I
import tvm

# Problem sizes for the reproducer: the input/output matrices are (M, N)
# float64 buffers, and the rolling rank is computed over `period` columns.
M: T.int64 = 128
N: T.int64 = 16
period: T.int64 = 10


@T.prim_func
def my_ts_rank_window(arr: T.handle, ret: T.handle):
    # Rolling time-series rank over a trailing window of length `period`:
    # each output cell holds the fraction (in increments of 1/period) of
    # window entries that are <= the window's last entry.  Cells whose
    # last window entry is NaN, and the first `period - 1` columns (no
    # complete window yet), are filled with NaN.
    T.func_attr({"global_symbol": "my_ts_rank_window", "tir.noalias": True})
    A = T.match_buffer(arr, (M, N), "float64")  # input matrix
    O = T.match_buffer(ret, (M, N), "float64")  # output matrix, same shape

    offset = period - 1  # number of leading columns without a full window
    step = T.float64(1) / T.Cast("float64", period)  # rank increment per hit

#    with T.block("ts_rank_root"):
    if True:
        # First nest: rank every column that has a complete trailing window.
        for i, j in T.grid(M, N - offset):
            with T.block("rank_compute"):
                vi = T.axis.spatial(M, i)
                vj = T.axis.spatial(N - offset, j)


                # A NaN in the window's last slot propagates to the output.
                with T.If(T.isnan(A[vi, vj + offset])):
                    with T.Then():
                        O[vi, vj + offset] = T.float64("nan")
                    with T.Else():
                        # Reduction over the window: accumulate `step` for
                        # every entry <= the window's last entry.
                        for k in T.serial(0, period):
                            with T.block("rank_inner"):
                                with T.init():
                                    O[vi, vj + offset] = T.float64(0)
                                # NOTE(review): the reduce axis is bound
                                # *after* T.init(); TVMScript convention puts
                                # axis bindings first in the block — confirm
                                # this parses as intended.
                                vk = T.axis.reduce(period, k)
                                # NOTE(review): a guarded store inside a
                                # reduction block may not be recognized as a
                                # standard reduction pattern by the scheduler.
                                if(A[vi, vj + vk] <= A[vi, vj + offset]):
                                    O[vi, vj + offset] =  O[vi, vj + offset] + step
        # Second nest: the first `offset` columns never see a full window.
        # NOTE(review): two sibling loop nests at function root — the tuner
        # binds blockIdx.x per nest, which appears to be what trips the
        # equal-extent check in UnifyThreadBinding in the traceback below.
        for i, j in T.grid(M, offset):
            with T.block("fill_prefix_nan"):
                vi = T.axis.spatial(M, i)
                vj = T.axis.spatial(offset, j)
                O[vi, vj] = T.float64("nan")

from tvm import meta_schedule as ms

# Wrap the PrimFunc into an IRModule and simplify it before tuning.
mod = tvm.IRModule({"main": my_ts_rank_window})
mod = tvm.tir.transform.Simplify()(mod)
print("== specialize_func is", mod.show())

# CUDA target described explicitly through its attribute dictionary.
target = tvm.target.Target(
    {
        "kind": "cuda",
        "max_threads_per_block": 1024,
        "max_shared_memory_per_block": 49152,
        "thread_warp_size": 32,
    }
)

# Run MetaSchedule TIR tuning; the resulting tuning database is returned.
db = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=1024,
    work_dir="./ms_rank_window",
)

The log is here:

Traceback (most recent call last):
  File "/home/tvm/issue.py", line 54, in <module>
    db = ms.tune_tir(
        mod=mod,
    ...<2 lines>...
        work_dir="./ms_rank_window",
    )
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/tir_integration.py", line 146, in tune_tir
    return tune_tasks(
        tasks=tasks,
    ...<12 lines>...
        post_optimization=post_optimization,
    )
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/tune.py", line 122, in tune_tasks
    task_scheduler.tune(
    ~~~~~~~~~~~~~~~~~~~^
        tasks=tasks,
        ^^^^^^^^^^^^
    ...<8 lines>...
        cost_model=cost_model,
        ^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
    _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
    ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self,
        ^^^^^
    ...<9 lines>...
        cost_model,
        ^^^^^^^^^^^
    )
    ^
  File "tvm/_ffi/_cython/packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/packed_func.pxi", line 284, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/tvm/repos/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "/home/tvm/repos/tvm/src/meta_schedule/task_scheduler/gradient_based.cc", line 54, in tvm::meta_schedule::GradientBasedNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
    TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/task_scheduler/task_scheduler.cc", line 179, in tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
    for (int task_id; num_trials_already < max_trials_global && (task_id = NextTaskId()) != -1;) {
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/task_scheduler/gradient_based.cc", line 72, in tvm::meta_schedule::GradientBasedNode::NextTaskId()
    this->JoinRunningTask(i);
            ^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/task_scheduler/gradient_based.cc", line 126, in tvm::meta_schedule::GradientBasedNode::JoinRunningTask(int)
    Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/task_scheduler/task_scheduler.cc", line 232, in tvm::meta_schedule::TaskSchedulerNode::JoinRunningTask(int)
    callback->Apply(GetRef<TaskScheduler>(this), task_id, task->measure_candidates.value(),
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/measure_callback/update_cost_model.cc", line 53, in tvm::meta_schedule::UpdateCostModelNode::Apply(tvm::meta_schedule::TaskScheduler const&, int, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&, tvm::runtime::Array<tvm::meta_schedule::BuilderResult, void> const&, tvm::runtime::Array<tvm::meta_schedule::RunnerResult, void> const&)
    cost_model->Update(task->ctx, pruned_candidate, pruned_runner_result);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/meta_schedule/cost_model/cost_model.cc", line 37, in tvm::meta_schedule::PyCostModelNode::Update(tvm::meta_schedule::TuneContext const&, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&, tvm::runtime::Array<tvm::meta_schedule::RunnerResult, void> const&)
    f_update(context, candidates, results);
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/utils.py", line 76, in method
    return getattr(inst, name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/cost_model/xgb_model.py", line 495, in update
    new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)]
                                         ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/python/tvm/meta_schedule/feature_extractor/feature_extractor.py", line 58, in extract_from
    result = _ffi_api.FeatureExtractorExtractFrom(  # type: ignore # pylint: disable=no-member
        self, context, candidates
    )
  File "tvm/_ffi/_cython/packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/tvm/repos/tvm/src/meta_schedule/feature_extractor/per_store_feature.cc", line 1414, in tvm::meta_schedule::PerStoreFeatureNode::ExtractFrom(tvm::meta_schedule::TuneContext const&, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&)
    support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/repos/tvm/src/support/parallel_for.cc", line 139, in tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
    LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what();
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
tvm._ffi.base.TVMError: Traceback (most recent call last):
  1: tvm::meta_schedule::PerStoreFeatureNode::ExtractFrom(tvm::meta_schedule::TuneContext const&, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&)
        at /home/tvm/repos/tvm/src/meta_schedule/feature_extractor/per_store_feature.cc:1414
  0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
        at /home/tvm/repos/tvm/src/support/parallel_for.cc:139
  27: 0xffffffffffffffff
  26: clone
  25: start_thread
  24: execute_native_thread_routine
        at ../../../../../libstdc++-v3/src/c++11/thread.cc:104
  23: std::packaged_task<void (int)>::operator()(int)
  22: __pthread_once_slow
  21: _ZNKSt8functionIFSt10unique_ptrINSt13__future_base12_Result_baseENS
  20: operator()
        at /home/tvm/repos/tvm/src/support/parallel_for.cc:113
  19: _ZNSt17_Function_handlerIFviiEZN3tvm13meta_schedule19PerStoreFeatureNode11ExtractFromERKN
  18: _ZSt10__invoke_rIvRZN3tvm13meta_schedule19PerStoreFeatureNode11ExtractFromERKNS1_11TuneCo
  17: _ZSt13__invoke_implIvRZN3tvm13meta_schedule19PerStoreFeatureNode11ExtractFromERKNS1_11Tu
  16: tvm::meta_schedule::PerStoreFeatureNode::ExtractFrom(tvm::meta_schedule::TuneContext const&, tvm::runtime::Array<tvm::meta_schedule::MeasureCandidate, void> const&)::{lambda(int, int)#1}::operator()(int, int) const
        at /home/tvm/repos/tvm/src/meta_schedule/feature_extractor/per_store_feature.cc:1406
  15: tvm::meta_schedule::PerStoreFeatureNode::ExtractSingle(tvm::IRModule, bool, std::vector<std::vector<double, std::allocator<double> >, std::allocator<std::vector<double, std::allocator<double> > > >*)
        at /home/tvm/repos/tvm/src/meta_schedule/feature_extractor/per_store_feature.cc:1376
  14: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
        at /home/tvm/repos/tvm/src/tir/ir/transform.cc:121
  13: operator()
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:197
  12: tvm::tir::UnifyThreadBinding(tvm::tir::PrimFunc)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:189
  11: tvm::tir::ThreadBindingUnifier::Unify(tvm::tir::Stmt)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:44
  10: tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)
        at /home/tvm/repos/tvm/src/tir/ir/stmt_functor.cc:211
  9: tvm::runtime::Array<tvm::tir::Stmt, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::Stmt>::value, void>::type> tvm::tir::StmtMutator::Internal::MutateArray<tvm::tir::Stmt, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}>(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::Stmt>::value, void>::type> const&, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
        at /home/tvm/repos/tvm/src/tir/ir/stmt_functor.cc:184
  8: tvm::runtime::Array<tvm::tir::Stmt, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::Stmt>::value, void>::type> tvm::runtime::Array<tvm::tir::Stmt, void>::Map<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}) const
        at /home/tvm/repos/tvm/include/tvm/runtime/container/array.h:652
  7: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::ObjectPtr<tvm::runtime::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
        at /home/tvm/repos/tvm/include/tvm/runtime/container/array.h:823
  6: tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}::operator()(tvm::tir::Stmt const&) const
        at /home/tvm/repos/tvm/src/tir/ir/stmt_functor.cc:210
  5: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:64
  4: tvm::tir::Stmt tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:131
  3: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:64
  2: tvm::tir::Stmt tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:131
  1: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:64
  0: tvm::tir::Stmt tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
        at /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:112
  File "/home/tvm/repos/tvm/src/support/parallel_for.cc", line 139
RuntimeError: parallel_for_dynamic error with [17:45:03] /home/tvm/repos/tvm/src/tir/transforms/unify_thread_binding.cc:112: Check failed: (ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) is false: ValueError: All loops that are bound to `blockIdx.x` should have the same extent. However, there are two loops with extent 14 and 1, which are not equal

I have found similar issues that were resolved in 0.17.0, but I am using 0.20.0 — so is this a regression?