[Bug][MetaSchedule]Encountered an error while tuning R.nn.attention using meta_schedule

import tvm
from tvm import meta_schedule as ms
from tvm import relax

from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I

from typing import List


@tvm.script.ir_module
class Attention:
    @R.function
    def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4, 32, 16, 8), "float32")):
        scale = T.FloatImm("float32", 0.1)
        gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="TopLeft")
        return gv

mod: tvm.IRModule = Attention

mod.show()

target = tvm.target.Target("nvidia/geforce-rtx-3060")
dev = tvm.cuda(0)

seq = tvm.transform.Sequential(
    [
        relax.get_pipeline("zero"),
        relax.transform.DeadCodeElimination(),
    ]
)
mod = seq(mod)

mod.show()

# with target:
#     mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

database = ms.tune_tir(
    mod,
    target,
    "./tuning_logs",
    10,
)

with target:
    mod = relax.transform.MetaScheduleApplyDatabase("./tuning_logs")(mod)

mod.show()

I encountered an error when using tune_tir to tune an IRModule that only contains R.nn.attention .

Traceback (most recent call last):
  File "attention.py", line 40, in <module>
    database = ms.tune_tir(
  File "/home/yrx/develop/my/tvm/python/tvm/meta_schedule/tir_integration.py", line 145, in tune_tir
    return tune_tasks(
  File "/home/yrx/develop/my/tvm/python/tvm/meta_schedule/tune.py", line 118, in tune_tasks
    task_scheduler.tune(
  File "/home/yrx/develop/my/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
    _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 284, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/yrx/develop/my/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  6: _ZN3tvm7runtime13PackedFuncObj
  5: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  4: tvm::meta_schedule::GradientBasedNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  3: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  2: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  1: tvm::meta_schedule::CrossThreadReductionNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  0: tvm::meta_schedule::CrossThreadReductionNode::GetThreadIdxExtentFromTrace(tvm::tir::Trace const&)
  File "/home/yrx/develop/my/tvm/include/tvm/runtime/container/array.h", line 414
InternalError: Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 6 on an array of size 6