[Bug] [MetaSchedule] Occasional error tuning PyTorch model

I am attempting to tune the pretrained AlexNet model from torchvision with MetaSchedule. My minimal script (see below) is based on one of the MetaSchedule unit tests (test_meta_schedule_tune_relay.py).

Most of the time, the script produces the intended result: the network is tuned successfully via MetaSchedule. However, a small fraction of the time it fails with the following error:

Traceback (most recent call last):
  File "min_error.py", line 23, in <module>
    rt_mod: tvm.module = tune_relay(
  File "/workspaces/tvm/python/tvm/meta_schedule/tune.py", line 713, in tune_relay
    task_scheduler.tune()
  File "/workspaces/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 61, in tune
    _ffi_api.TaskSchedulerTune(self)  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  8: TVMFuncCall
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  6: tvm::meta_schedule::TaskSchedulerNode::Tune()
  5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  4: tvm::meta_schedule::RandomComputeLocationNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  3: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
  2: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
  1: tvm::tir::ScheduleStateNode::DebugVerify() const
  0: tvm::tir::VerifyCachedFlags(tvm::tir::ScheduleState const&)
  File "/workspaces/tvm/src/tir/schedule/analysis/verify.cc", line 236
TVMError: Schedule verification failed. The IR is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(placeholder: T.Buffer[(1, 96, 1, 1, 4), "float32"], placeholder_1: T.Buffer[(96, 48, 3, 3, 4, 4), "float32"], placeholder_2: T.Buffer[(1, 48, 13, 13, 4), "float32"], T_relu: T.Buffer[(1, 96, 13, 13, 4), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":32, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
            data_pad = T.alloc_buffer([1, 48, 15, 15, 4], dtype="float32")
            conv2d_NCHWc = T.alloc_buffer([1, 96, 13, 13, 4], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0, i4_0, i0_1, i1_1, i2_1, i3_1, i4_1, i5_0, i6_0, i7_0 in T.grid(1, 3, 1, 1, 2, 1, 8, 1, 1, 2, 32, 1, 3):
                for ax0 in T.serial(1):
                    for ax1, ax2, ax3, ax4 in T.grid(T.min(47 - i5_0 * 6 // 4, (i5_0 * 6 % 4 + 5) // 4) + 1, 15, 13, 4):
                        with T.block("data_pad"):
                            i0 = T.axis.spatial(1, ax0)
                            i1 = T.axis.spatial(48, i5_0 * 6 // 4 + ax1)
                            i2 = T.axis.spatial(15, ax2)
                            i3 = T.axis.spatial(15, i7_0 + ax3)
                            i4 = T.axis.spatial(4, ax4)
                            T.reads(placeholder_2[i0, i1, T.max(i2 - 1, 0) : T.max(i2 - 1, 0) + (T.min(i2, 13) - T.max(i2 - 1, 0)), T.max(i3 - 1, 0) : T.max(i3 - 1, 0) + (T.min(i3, 13) - T.max(i3 - 1, 0)), i4])
                            T.writes(data_pad[i0, i1, i2, i3, i4])
                            data_pad[i0, i1, i2, i3, i4] = T.if_then_else(1 <= i2 and i2 < 14 and 1 <= i3 and i3 < 14, placeholder_2[i0, i1, i2 - 1, i3 - 1, i4], T.float32(0), dtype="float32")
                for i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 1, 13, 1, 6, 3, 1, 1, 4, 13, 1, 1):
                    with T.block("conv2d_NCHWc"):
                        n = T.axis.spatial(1, 0)
                        oc_chunk = T.axis.spatial(96, i1_0 * 32 + i1_1 * 4 + i1_3)
                        oh, ow = T.axis.remap("SS", [i2_3, i3_2])
                        oc_block = T.axis.spatial(4, i4_0 * 2 + i4_1)
                        ic = T.axis.reduce(192, i5_0 * 6 + i5_1)
                        kh, kw = T.axis.remap("RR", [i6_1, i7_0])
                        T.reads(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block], data_pad[n, ic // 4, oh + kh, ow + kw, ic % 4], placeholder_1[oc_chunk, ic // 4, kh, kw, ic % 4, oc_block])
                        T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS", "workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 48, 13, 13, 4], "float32"], ["TENSOR", [96, 48, 3, 3, 4, 4], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW4c", "NCHW4c", "float32"]})
                        with T.init():
                            conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
                        conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 4, oh + kh, ow + kw, ic % 4] * placeholder_1[oc_chunk, ic // 4, kh, kw, ic % 4, oc_block]
            for i0, i1, i2, i3, i4 in T.grid(1, 96, 13, 13, 4):
                with T.block("T_relu"):
                    ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                    T.reads(conv2d_NCHWc[ax0, ax1, ax2, ax3, ax4], placeholder[ax0, ax1, 0, 0, ax4])
                    T.writes(T_relu[ax0, ax1, ax2, ax3, ax4])
                    T_relu[ax0, ax1, ax2, ax3, ax4] = T.max(conv2d_NCHWc[ax0, ax1, ax2, ax3, ax4] + placeholder[ax0, ax1, 0, 0, ax4], T.float32(0))
    

The errors are:
- Wrong region_cover:  (conv2d_NCHWc, expected=0, actual=1)
- Wrong stage_pipeline:  (root, expected=0, actual=1)

I suspect this might be due to random sampling within MetaSchedule, but I would appreciate any help in getting around it.
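
The only stopgap I can think of for now is wrapping the tuning call in a retry loop and re-running it whenever this particular verification error appears. The sketch below is only an illustration of that idea, assuming the failure is transient; tune_relay_with_retry, MAX_RETRIES, and the error-message check are my own hypothetical additions, not an existing API:

from tvm.meta_schedule.tune import tune_relay
import tvm

MAX_RETRIES = 3  # arbitrary retry budget

def tune_relay_with_retry(**tune_kwargs):
    # Retry tune_relay when the intermittent verification error shows up.
    for attempt in range(MAX_RETRIES):
        try:
            return tune_relay(**tune_kwargs)
        except tvm.TVMError as err:
            if "Schedule verification failed" not in str(err):
                raise  # unrelated failure, do not mask it
            if attempt == MAX_RETRIES - 1:
                raise  # retry budget exhausted, surface the error
            print(f"Verification failed on attempt {attempt + 1}, retrying")

With the script below, the final call would then read rt_mod = tune_relay_with_retry(mod=mod, params=params, target=target, config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32), work_dir=work_dir). That said, I would still like to understand the root cause rather than rely on retries.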

Thanks!

Script causing the error:

import tempfile
import torch
import torchvision

import tvm
from tvm import relay
from tvm.target.target import Target
from tvm.meta_schedule import ReplayTraceConfig
from tvm.meta_schedule.tune import tune_relay


# Trace the pretrained torchvision AlexNet model with TorchScript.
indata = torch.rand((1, 3, 224, 224))
model = torchvision.models.alexnet(pretrained=True)
model.eval()
scripted_model = torch.jit.trace(model, indata)
scripted_model.eval()

# Convert the traced model to a Relay module.
shape_list = [("input0", indata.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# Tune the Relay module with MetaSchedule using the replay-trace strategy.
with tempfile.TemporaryDirectory() as work_dir:
    target = Target("llvm --num-cores=2")
    rt_mod: tvm.module = tune_relay(
        mod=mod,
        params=params,
        target=target,
        config=ReplayTraceConfig(
            num_trials_per_iter=32,
            num_trials_total=32,
        ),
        work_dir=work_dir
    )

Configuration details:

CC @spectrometerHBH, please take a look.