MetaScheduleTuneTIR does not work well: 0 trials and a build failure for one conv2d shape, only 64 of 2000 trials for another

TVM version: 0.21.0 (latest release), torch version: 2.4.0+cu121

import tvm
from tvm import relax
from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels=128, out_channels=256, kernel_size=3, bias=False
        )

    def forward(self, x):
        x = self.conv(x)
        return x


input_shape = (512, 128, 68, 68)  # fails to build; (128, 128, 68, 68) stops after 64 trials
mod, params = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor(input_shape, "float32")}}
)


## Tuning
import tempfile

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
trials = 2000
with target, tempfile.TemporaryDirectory() as tmp_dir:
    mod = tvm.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            # Skip tuning if trials is 0
            (
                relax.transform.MetaScheduleTuneTIR(tmp_dir, trials)
                if trials > 0
                else tvm.transform.Sequential([])
            ),
            relax.transform.MetaScheduleApplyDatabase(tmp_dir),
        ]
    )(mod)

mod.show()

# deploy
import numpy as np

ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)

# Need to allocate data and params on GPU device
data = tvm.nd.array(np.ones(input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.ones(p.shape).astype(p.dtype), dev) for _, p in params]

test = vm.time_evaluator(func_name="forward", dev=dev, min_repeat_ms=5000)
print(test(data, *gpu_params))
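
For reference, roughly the same flow can be driven by the prebuilt static_shape_tuning pipeline, which bundles MetaSchedule tuning with a default GPU-schedule fallback. A minimal sketch, assuming relax.get_pipeline("static_shape_tuning", ...) keeps the target / total_trials arguments shown in the TVM tutorials:

# Sketch of the prebuilt tuning pipeline (assumption: available in this release
# with these arguments). It tunes with MetaSchedule, applies the database, and
# falls back to a default GPU schedule for anything left unscheduled.
mod_tuned = relax.get_pipeline(
    "static_shape_tuning", target=target, total_trials=trials
)(mod)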

It does not work and throws an error for input_shape = (512, 128, 68, 68):


(tvm) [alen@node1 test]$ python test.py 
2025-10-29 22:45:21 [INFO] Logging directory: /tmp/tmpr8tbwyfj/logs
2025-10-29 22:45:33 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:45:37 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:45:41 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:45:42 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:326] 
 ID | Name |          FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------------
  0 | main | 1315467952128 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

[22:45:48] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpr8tbwyfj/database_workload.json, Tuning records at: /tmp/tmpr8tbwyfj/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)))
        for i0, i1, i2, i3 in T.grid(T.int64(512), T.int64(128), T.int64(68), T.int64(68)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3]
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(512), T.int64(256), T.int64(66), T.int64(66), T.int64(128), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight[v_ff, v_rc, v_ry, v_rx]

    @R.function
    def forward(x: R.Tensor((512, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((512, 256, 66, 66), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((512, 256, 66, 66), dtype="float32"))
            R.output(gv)
        return gv

Traceback (most recent call last):
  File "/home/src/conv.py", line 153, in <module>
    ex = tvm.compile(mod, target="cuda")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/driver/build_module.py", line 104, in compile
    return tvm.relax.build(
           ^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/relax/vm_build.py", line 259, in build
    return _vmlink(
           ^^^^^^^^
  File "/home/tvm/python/tvm/relax/vm_build.py", line 154, in _vmlink
    lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/tir/build.py", line 173, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
  File "/home/tvm/src/ir/transform.cc", line 551, in operator()
    .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 545, in operator()
    return pass_func(ffi::RValueRef<IRModule>(std::move(mod)), ctx);
^^^
  File "tvm/ffi/cython/function.pxi", line 383, in tvm.ffi.core.tvm_ffi_callback
  File "/home/tvm/python/tvm/tir/pipeline.py", line 122, in _pipeline
    mod = tvm.ir.transform.Sequential(passes)(mod)
^^^^^^^
  File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
  File "/home/tvm/src/ir/transform.cc", line 551, in operator()
    .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
  ^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 490, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
  File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
  ^^^^^^^^^^^
  File "/home/tvm/src/tir/analysis/verify_memory.cc", line 203, in operator()
    LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
^^^^^
RuntimeError: Memory verification failed with the following errors:
    Variable `x` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `conv_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  Did you forget to bind?
# from tvm.script import tir as T

@T.prim_func
def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
    T.func_attr({"op_pattern": 4, "target": T.target({"arch": "sm_80", "host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-conda-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
    pad_temp = T.allocate([303038464], "float32", "global")
    pad_temp_1 = T.Buffer((T.int64(303038464),), data=pad_temp)
    for i0, i1, i2, i3 in T.grid(512, 128, 68, 68):
        cse_v1: T.int32 = i0 * 591872 + i1 * 4624 + i2 * 68 + i3
        x_1 = T.Buffer((T.int64(303038464),), data=x.data)
        pad_temp_1[cse_v1] = x_1[cse_v1]
    for nn, ff, yy, xx, rc, ry, rx in T.grid(512, 256, 66, 66, 128, 3, 3):
        cse_v2: T.int32 = nn * 1115136 + ff * 4356 + yy * 66 + xx
        conv2d_nchw_1 = T.Buffer((T.int64(570949632),), data=conv2d_nchw.data)
        if rc == 0 and ry == 0 and rx == 0:
            conv2d_nchw_1[cse_v2] = T.float32(0.0)
        conv_weight_1 = T.Buffer((T.int64(294912),), data=conv_weight.data)
        conv2d_nchw_1[cse_v2] = conv2d_nchw_1[cse_v2] + pad_temp_1[nn * 591872 + rc * 4624 + yy * 68 + ry * 68 + xx + rx] * conv_weight_1[ff * 1152 + rc * 9 + ry * 3 + rx]
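
The failure looks consistent with the table above: the task finished with 0 trials, so MetaScheduleApplyDatabase had no record to apply, and the conv2d PrimFunc reached the CUDA build with no thread bindings, which is what the memory-verification error complains about. A possible stop-gap (not a fix for the tuner itself) is to append a default GPU schedule so untuned PrimFuncs still get bound. A minimal sketch, assuming tvm.tir.transform.DefaultGPUSchedule is the right fallback pass and is run inside the same with-target block as above:

# Sketch of a fallback schedule after applying the database (assumption:
# tvm.tir.transform.DefaultGPUSchedule is available here and binds any PrimFunc
# the tuning database did not cover). Replace the Sequential inside the
# "with target, tempfile.TemporaryDirectory() as tmp_dir" block with this.
mod = tvm.transform.Sequential(
    [
        relax.get_pipeline("zero"),
        relax.transform.MetaScheduleTuneTIR(tmp_dir, trials),
        relax.transform.MetaScheduleApplyDatabase(tmp_dir),
        tvm.tir.transform.DefaultGPUSchedule(),  # bind leftover unscheduled loops
    ]
)(mod)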

For input_shape = (128, 128, 68, 68), the search stopped after only 64 trials, far fewer than the 2000 I set. However, no error was displayed:

(tvm) [alen@node1 test]$ python test.py 
2025-10-29 22:33:26 [INFO] Logging directory: /tmp/tmpu9i7vb6z/logs
2025-10-29 22:33:38 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:33:42 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:33:46 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:33:47 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:34:08 [INFO] [task_scheduler.cc:199] Sending 64 sample(s) to builder
2025-10-29 22:34:47 [INFO] [task_scheduler.cc:201] Sending 64 sample(s) to runner
2025-10-29 22:42:31 [DEBUG] XGB iter   0: tr-p-rmse: 0.590935   tr-a-peak@32: 0.955485  tr-rmse: 0.415917       tr-rmse: 0.415917
2025-10-29 22:42:31 [DEBUG] XGB iter  25: tr-p-rmse: 0.054546   tr-a-peak@32: 1.000000  tr-rmse: 0.454526       tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB iter  50: tr-p-rmse: 0.054546   tr-a-peak@32: 1.000000  tr-rmse: 0.454526       tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB stopped. Best iteration: [11] tr-p-rmse:0.05455 tr-a-peak@32:1.00000    tr-rmse:0.45453 tr-rmse:0.45453 
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:243] [Updated] Task #0: "main"
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:326] 
 ID | Name |         FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------------
  0 | main | 328866988032 |      1 |      7501.8255 |   43838.2617 |            43838.2617 |     64 |      
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3

2025-10-29 22:42:31 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:326] 
 ID | Name |         FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------------
  0 | main | 328866988032 |      1 |      7501.8255 |   43838.2617 |            43838.2617 |     64 |    Y 
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3

[22:42:44] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpu9i7vb6z/database_workload.json, Tuning records at: /tmp/tmpu9i7vb6z/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(x: T.Buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True})
        # with T.block("root"):
        conv2d_nchw_local = T.alloc_buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), scope="local")
        pad_temp_shared = T.alloc_buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), scope="shared")
        conv_weight_shared = T.alloc_buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), scope="shared")
        for nn_0_ff_0_yy_0_xx_0_fused in T.thread_binding(T.int64(15488), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for nn_1_ff_1_yy_1_xx_1_fused in T.thread_binding(T.int64(8), thread="vthread.x"):
                for nn_2_ff_2_yy_2_xx_2_fused in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                    for nn_3_init, ff_3_init, yy_3_init, xx_3_init, nn_4_init, ff_4_init, yy_4_init, xx_4_init in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                        with T.block("conv2d_nchw_init"):
                            v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3_init + nn_4_init)
                            v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3_init + ff_4_init)
                            v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3_init + yy_4_init)
                            v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3_init + xx_4_init)
                            T.reads()
                            T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
                            T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
                            conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
                    for rc_0, ry_0, rx_0 in T.grid(T.int64(32), T.int64(1), T.int64(3)):
                        for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(32)):
                            for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                                with T.block("pad_temp_shared"):
                                    v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) // T.int64(192))
                                    v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(192) // T.int64(48))
                                    v2 = T.axis.spatial(T.int64(68), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(48) // T.int64(6))
                                    v3 = T.axis.spatial(T.int64(68), rx_0 + nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(6))
                                    T.reads(x[v0, v1, v2, v3])
                                    T.writes(pad_temp_shared[v0, v1, v2, v3])
                                    pad_temp_shared[v0, v1, v2, v3] = x[v0, v1, v2, v3]
                        for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):
                            for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
                                for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)):
                                    with T.block("conv.weight_shared"):
                                        v0 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(12))
                                        v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(12) // T.int64(3))
                                        v2 = T.axis.spatial(T.int64(3), (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(3))
                                        v3 = T.axis.spatial(T.int64(3), rx_0)
                                        T.reads(conv_weight[v0, v1, v2, v3])
                                        T.writes(conv_weight_shared[v0, v1, v2, v3])
                                        conv_weight_shared[v0, v1, v2, v3] = conv_weight[v0, v1, v2, v3]
                        for rc_1, ry_1, rx_1, nn_3, ff_3, yy_3, xx_3, rc_2, ry_2, rx_2, nn_4, ff_4, yy_4, xx_4 in T.grid(T.int64(4), T.int64(3), T.int64(1), T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                            with T.block("conv2d_nchw_update"):
                                v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3 + nn_4)
                                v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3 + ff_4)
                                v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3 + yy_4)
                                v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3 + xx_4)
                                v_rc = T.axis.reduce(T.int64(128), rc_0 * T.int64(4) + rc_1 + rc_2)
                                v_ry = T.axis.reduce(T.int64(3), ry_0 * T.int64(3) + ry_1 + ry_2)
                                v_rx = T.axis.reduce(T.int64(3), rx_0 + rx_1 + rx_2)
                                T.reads(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx], pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight_shared[v_ff, v_rc, v_ry, v_rx])
                                T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
                                T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
                                conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] + pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight_shared[v_ff, v_rc, v_ry, v_rx]
                    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1)):
                        with T.block("conv2d_nchw_local"):
                            v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + ax0)
                            v1 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ax1)
                            v2 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + ax2)
                            v3 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + ax3)
                            T.reads(conv2d_nchw_local[v0, v1, v2, v3])
                            T.writes(conv2d_nchw[v0, v1, v2, v3])
                            conv2d_nchw[v0, v1, v2, v3] = conv2d_nchw_local[v0, v1, v2, v3]

    @R.function
    def forward(x: R.Tensor((128, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((128, 256, 66, 66), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((128, 256, 66, 66), dtype="float32"))
            R.output(gv)
        return gv

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  28.1586      28.1586      28.1586      28.1586       0.0000
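
To see what the 64 trials actually left behind, the JSON database written to the work directory can be inspected. A minimal sketch, assuming a persistent work_dir is used instead of the temporary directory in the script (which is deleted when the with-block exits); the file names come from the warning in the log above:

# Sketch: count the tuning records MetaSchedule stored.
# work_dir is a hypothetical persistent tuning directory.
from tvm import meta_schedule as ms

work_dir = "tuning_logs"
db = ms.database.JSONDatabase(
    path_workload=f"{work_dir}/database_workload.json",
    path_tuning_record=f"{work_dir}/database_tuning_record.json",
)
print("tuning records:", len(db.get_all_tuning_records()))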