MetaScheduleTuneTIR does not work well

Code version: 0.21.0 (latest release); torch version: 2.4.0+cu121

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    """Minimal model wrapping a single bias-free 3x3 convolution layer."""

    def __init__(self):
        super().__init__()
        # 128 input channels -> 256 output channels, 3x3 kernel, no bias.
        conv_config = dict(in_channels=128, out_channels=256, kernel_size=3, bias=False)
        self.conv = nn.Conv2D(**conv_config)

    def forward(self, x):
        """Apply the convolution layer to ``x`` and return its output."""
        return self.conv(x)


# Shape of the single float32 input tensor fed to `forward`.
input_shape = (512, 128, 68, 68)
# Export the model to an IRModule plus its parameter list, describing the
# `forward` entry point by the spec of its one argument `x`.
mod, params = RelaxModel().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor(input_shape, "float32")}},
)


## Tuning
import tempfile

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
trials = 2000
with target, tempfile.TemporaryDirectory() as tmp_dir:
    # When no trials are requested, substitute an empty pass sequence for the
    # tuning step so that tuning is skipped.
    if trials > 0:
        tune_pass = relax.transform.MetaScheduleTuneTIR(tmp_dir, trials)
    else:
        tune_pass = tvm.transform.Sequential([])

    # "zero" pipeline, then (optional) tuning, then apply whatever records
    # were written to the database in tmp_dir.
    passes = [
        relax.get_pipeline("zero"),
        tune_pass,
        relax.transform.MetaScheduleApplyDatabase(tmp_dir),
    ]
    mod = tvm.transform.Sequential(passes)(mod)

# Print the resulting module so the applied schedule (if any) can be inspected.
mod.show()

# deploy
import numpy as np

# Build the module for CUDA and load it into a Relax virtual machine on GPU 0.
ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)

# Input data and parameters must live on the GPU device; all are filled with ones.
host_input = np.ones(input_shape).astype("float32")
data = tvm.nd.array(host_input, dev)

gpu_params = []
for _name, param in params:
    host_param = np.ones(param.shape).astype(param.dtype)
    gpu_params.append(tvm.nd.array(host_param, dev))

# Time the `forward` function and print the timing result.
test = vm.time_evaluator(func_name="forward", dev=dev, min_repeat_ms=5000)
timing_result = test(data, *gpu_params)
print(timing_result)

For input_shape = (512, 128, 68, 68), it does not work: tuning finishes with 0 trials, and compilation then throws a memory verification error:


(tvm) [alen@node1 test]$ python test.py 
2025-10-29 22:45:21 [INFO] Logging directory: /tmp/tmpr8tbwyfj/logs
2025-10-29 22:45:33 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:45:37 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:45:41 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:45:42 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:326] 
 ID | Name |          FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------------
  0 | main | 1315467952128 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

[22:45:48] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpr8tbwyfj/database_workload.json, Tuning records at: /tmp/tmpr8tbwyfj/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)))
        for i0, i1, i2, i3 in T.grid(T.int64(512), T.int64(128), T.int64(68), T.int64(68)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3]
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(512), T.int64(256), T.int64(66), T.int64(66), T.int64(128), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight[v_ff, v_rc, v_ry, v_rx]

    @R.function
    def forward(x: R.Tensor((512, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((512, 256, 66, 66), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((512, 256, 66, 66), dtype="float32"))
            R.output(gv)
        return gv

Traceback (most recent call last):
  File "/home/src/conv.py", line 153, in <module>
    ex = tvm.compile(mod, target="cuda")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/driver/build_module.py", line 104, in compile
    return tvm.relax.build(
           ^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/relax/vm_build.py", line 259, in build
    return _vmlink(
           ^^^^^^^^
  File "/home/tvm/python/tvm/relax/vm_build.py", line 154, in _vmlink
    lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/tir/build.py", line 173, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
  File "/home/tvm/src/ir/transform.cc", line 551, in operator()
    .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 545, in operator()
    return pass_func(ffi::RValueRef<IRModule>(std::move(mod)), ctx);
^^^
  File "tvm/ffi/cython/function.pxi", line 383, in tvm.ffi.core.tvm_ffi_callback
  File "/home/tvm/python/tvm/tir/pipeline.py", line 122, in _pipeline
    mod = tvm.ir.transform.Sequential(passes)(mod)
^^^^^^^
  File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
  File "/home/tvm/src/ir/transform.cc", line 551, in operator()
    .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 490, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
  ^^^^^^^^^^^
  File "/home/tvm/src/tir/analysis/verify_memory.cc", line 203, in operator()
    LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
^^^^^
RuntimeError: Memory verification failed with the following errors:
    Variable `x` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  Did you forget to bind?
# from tvm.script import tir as T

@T.prim_func
def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
    T.func_attr({"op_pattern": 4, "target": T.target({"arch": "sm_80", "host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-conda-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
    pad_temp = T.allocate([303038464], "float32", "global")
    pad_temp_1 = T.Buffer((T.int64(303038464),), data=pad_temp)
    for i0, i1, i2, i3 in T.grid(512, 128, 68, 68):
        cse_v1: T.int32 = i0 * 591872 + i1 * 4624 + i2 * 68 + i3
        x_1 = T.Buffer((T.int64(303038464),), data=x.data)
        pad_temp_1[cse_v1] = x_1[cse_v1]
    for nn, ff, yy, xx, rc, ry, rx in T.grid(512, 256, 66, 66, 128, 3, 3):
        cse_v2: T.int32 = nn * 1115136 + ff * 4356 + yy * 66 + xx
        conv2d_nchw_1 = T.Buffer((T.int64(570949632),), data=conv2d_nchw.data)
        if rc == 0 and ry == 0 and rx == 0:
            conv2d_nchw_1[cse_v2] = T.float32(0.0)
        conv_weight_1 = T.Buffer((T.int64(294912),), data=conv_weight.data)
        conv2d_nchw_1[cse_v2] = conv2d_nchw_1[cse_v2] + pad_temp_1[nn * 591872 + rc * 4624 + yy * 68 + ry * 68 + xx + rx] * conv_weight_1[ff * 1152 + rc * 9 + ry * 3 + rx]

For input_shape = (128, 128, 68, 68), the search stopped after only 64 trials, far fewer than the 2000 I set. However, no error was displayed:

(tvm) [alen@node1 test]$ python test.py 
2025-10-29 22:33:26 [INFO] Logging directory: /tmp/tmpu9i7vb6z/logs
2025-10-29 22:33:38 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:33:42 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:33:46 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:33:47 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:34:08 [INFO] [task_scheduler.cc:199] Sending 64 sample(s) to builder
2025-10-29 22:34:47 [INFO] [task_scheduler.cc:201] Sending 64 sample(s) to runner
2025-10-29 22:42:31 [DEBUG] XGB iter   0: tr-p-rmse: 0.590935   tr-a-peak@32: 0.955485  tr-rmse: 0.415917       tr-rmse: 0.415917
2025-10-29 22:42:31 [DEBUG] XGB iter  25: tr-p-rmse: 0.054546   tr-a-peak@32: 1.000000  tr-rmse: 0.454526       tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB iter  50: tr-p-rmse: 0.054546   tr-a-peak@32: 1.000000  tr-rmse: 0.454526       tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB stopped. Best iteration: [11] tr-p-rmse:0.05455 tr-a-peak@32:1.00000    tr-rmse:0.45453 tr-rmse:0.45453 
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:243] [Updated] Task #0: "main"
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:326] 
 ID | Name |         FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------------
  0 | main | 328866988032 |      1 |      7501.8255 |   43838.2617 |            43838.2617 |     64 |      
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3

2025-10-29 22:42:31 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:326] 
 ID | Name |         FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------------
  0 | main | 328866988032 |      1 |      7501.8255 |   43838.2617 |            43838.2617 |     64 |    Y 
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3

[22:42:44] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpu9i7vb6z/database_workload.json, Tuning records at: /tmp/tmpu9i7vb6z/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(x: T.Buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True})
        # with T.block("root"):
        conv2d_nchw_local = T.alloc_buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), scope="local")
        pad_temp_shared = T.alloc_buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), scope="shared")
        conv_weight_shared = T.alloc_buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), scope="shared")
        for nn_0_ff_0_yy_0_xx_0_fused in T.thread_binding(T.int64(15488), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for nn_1_ff_1_yy_1_xx_1_fused in T.thread_binding(T.int64(8), thread="vthread.x"):
                for nn_2_ff_2_yy_2_xx_2_fused in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                    for nn_3_init, ff_3_init, yy_3_init, xx_3_init, nn_4_init, ff_4_init, yy_4_init, xx_4_init in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                        with T.block("conv2d_nchw_init"):
                            v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3_init + nn_4_init)
                            v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3_init + ff_4_init)
                            v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3_init + yy_4_init)
                            v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3_init + xx_4_init)
                            T.reads()
                            T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
                            T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
                            conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
                    for rc_0, ry_0, rx_0 in T.grid(T.int64(32), T.int64(1), T.int64(3)):
                        for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(32)):
                            for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                                with T.block("pad_temp_shared"):
                                    v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) // T.int64(192))
                                    v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(192) // T.int64(48))
                                    v2 = T.axis.spatial(T.int64(68), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(48) // T.int64(6))
                                    v3 = T.axis.spatial(T.int64(68), rx_0 + nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(6))
                                    T.reads(x[v0, v1, v2, v3])
                                    T.writes(pad_temp_shared[v0, v1, v2, v3])
                                    pad_temp_shared[v0, v1, v2, v3] = x[v0, v1, v2, v3]
                        for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):
                            for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                                for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)):
                                    with T.block("conv.weight_shared"):
                                        v0 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(12))
                                        v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(12) // T.int64(3))
                                        v2 = T.axis.spatial(T.int64(3), (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(3))
                                        v3 = T.axis.spatial(T.int64(3), rx_0)
                                        T.reads(conv_weight[v0, v1, v2, v3])
                                        T.writes(conv_weight_shared[v0, v1, v2, v3])
                                        conv_weight_shared[v0, v1, v2, v3] = conv_weight[v0, v1, v2, v3]
                        for rc_1, ry_1, rx_1, nn_3, ff_3, yy_3, xx_3, rc_2, ry_2, rx_2, nn_4, ff_4, yy_4, xx_4 in T.grid(T.int64(4), T.int64(3), T.int64(1), T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                            with T.block("conv2d_nchw_update"):
                                v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3 + nn_4)
                                v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3 + ff_4)
                                v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3 + yy_4)
                                v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3 + xx_4)
                                v_rc = T.axis.reduce(T.int64(128), rc_0 * T.int64(4) + rc_1 + rc_2)
                                v_ry = T.axis.reduce(T.int64(3), ry_0 * T.int64(3) + ry_1 + ry_2)
                                v_rx = T.axis.reduce(T.int64(3), rx_0 + rx_1 + rx_2)
                                T.reads(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx], pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight_shared[v_ff, v_rc, v_ry, v_rx])
                                T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
                                T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
                                conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] + pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight_shared[v_ff, v_rc, v_ry, v_rx]
                    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1)):
                        with T.block("conv2d_nchw_local"):
                            v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + ax0)
                            v1 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ax1)
                            v2 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + ax2)
                            v3 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + ax3)
                            T.reads(conv2d_nchw_local[v0, v1, v2, v3])
                            T.writes(conv2d_nchw[v0, v1, v2, v3])
                            conv2d_nchw[v0, v1, v2, v3] = conv2d_nchw_local[v0, v1, v2, v3]

    @R.function
    def forward(x: R.Tensor((128, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((128, 256, 66, 66), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((128, 256, 66, 66), dtype="float32"))
            R.output(gv)
        return gv

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  28.1586      28.1586      28.1586      28.1586       0.0000

I tried the latest version (main branch, commit f532b8) and got the same error :sob:

I’m not sure about the error, but the number of trials you set might be the global trial budget shared among all tasks. The default maximum number of trials per iteration is 64, so that could be the reason; you can change that too, if you want, with num_trials_per_iter. As for the error, maybe try using a different variable in the forward function?

out = self.conv(x)
return out

Thanks for your advice, but I got the same error. I guess that certain axis bindings violate the required constraints.