Crash during autoscheduling

Hi,

I’m trying to use the autoscheduler for a simple operator as shown below.

import os
import numpy as np
import tvm
from tvm import te, auto_scheduler, topi

@auto_scheduler.register_workload
def gemm(B, N, L, H):
    """Describe the workload O[b, l, o] = sum_k A[b, k // H, l, k % H] * W[o, k].

    A is a placeholder of shape (B, N, L, H), W a placeholder of shape
    (N * H, N * H), and O the computed tensor of shape (B, L, N * H).
    The reduction axis k spans [0, N * H) and is split into the second
    and fourth indices of A with floordiv/floormod by H.

    Returns the list [A, W, O].
    """
    hidden = N * H

    A = te.placeholder((B, N, L, H), name='A')
    W = te.placeholder((hidden, hidden), name='W')
    k = te.reduce_axis((0, hidden), name='k')

    # The argument names b, l, o are kept as in the original lambda.
    def _reduce_over_k(b, l, o):
        outer = te.floordiv(k, H)
        inner = te.floormod(k, H)
        return te.sum(A[b, outer, l, inner] * W[o, k], axis=k)

    O = te.compute((B, L, hidden), _reduce_over_k, name='O')

    return [A, W, O]

# Build the search task for a CUDA target.
target = tvm.target.Target("cuda")

# Concrete values passed as the (B, N, L, H) arguments of gemm.
B = 32
N = 8
L = 512
H = 64
task = auto_scheduler.SearchTask(func=gemm, args=(B, N, L, H), target=target)

print("Computational DAG:")
print(task.compute_dag)

# File handed to RecordToFile during tuning and passed to apply_best below.
log_file = "gemm_ff1.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

# NOTE(review): this line is not reached if tune() or apply_best() raises,
# because nothing above is wrapped in try/finally.
del measure_ctx

This, however, crashes (backtrace below) once the search has completed. After the crash the process hangs, and I have to interrupt it to quit. Am I doing something wrong?

Computational DAG:
A = PLACEHOLDER [32, 8, 512, 64]
W = PLACEHOLDER [512, 512]
O(b, l, o) += (A[b, floordiv(k, 64), l, floormod(k, 64)]*W[o, k])

Get devices for measurement successfully!

    sch, args = task.apply_best(log_file)
  File "/home/ppf/rnn_compilers/ragged_tensors/new_tvm/python/tvm/auto_scheduler/search_task.py", line 526, in apply_best
    inp.state, layout_rewrite_option or self.layout_rewrite_option
  File "/home/ppf/rnn_compilers/ragged_tensors/new_tvm/python/tvm/auto_scheduler/compute_dag.py", line 154, in apply_steps_from_state
    return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  6: TVMFuncCall
  5: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::ComputeDAG const&, tvm::auto_scheduler::State const&, int)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::ComputeDAG const&, tvm::auto_scheduler::State const&, int)#5}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::ComputeDAG const&, tvm::auto_scheduler::State const&, int)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  4: tvm::auto_scheduler::ComputeDAG::ApplySteps(tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::auto_scheduler::LayoutRewriteOption) const
  3: tvm::auto_scheduler::StepApplyToSchedule(tvm::auto_scheduler::Step const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*, tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&)
  2: _ZNK3tvm14auto_scheduler13SplitStepNode15ApplyToScheduleEPNS_7runtime5ArrayINS_2te5StageE
  1: tvm::auto_scheduler::ApplySplitToSchedule(tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, int, int, tvm::runtime::Array<tvm::runtime::Optional<tvm::Integer>, void> const&, bool)
  0: tvm::runtime::Array<tvm::tir::IterVar, void>::operator[](long) const
  File "../include/tvm/runtime/container/array.h", line 393
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 20 on an array of size 20