TVM version: 0.21.0 (latest release), torch version: 2.4.0+cu121
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.conv = nn.Conv2D(
            in_channels=128, out_channels=256, kernel_size=3, bias=False
        )

    def forward(self, x):
        x = self.conv(x)
        return x


input_shape = (512, 128, 68, 68)
mod, params = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor(input_shape, "float32")}}
)

## Tuning
import tempfile

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
trials = 2000
with target, tempfile.TemporaryDirectory() as tmp_dir:
    mod = tvm.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            # Skip tuning if total_trials is 0
            (
                relax.transform.MetaScheduleTuneTIR(tmp_dir, trials)
                if trials > 0
                else tvm.transform.Sequential([])
            ),
            relax.transform.MetaScheduleApplyDatabase(tmp_dir),
        ]
    )(mod)
    mod.show()

# deploy
import numpy as np

ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.ones(input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.ones(p.shape).astype(p.dtype), dev) for _, p in params]
test = vm.time_evaluator(func_name="forward", dev=dev, min_repeat_ms=5000)
print(test(data, *gpu_params))
It does not work for input_shape = (512, 128, 68, 68): the tuner finishes with 0 trials and tvm.compile then throws an error:
(tvm) [alen@node1 test]$ python test.py
2025-10-29 22:45:21 [INFO] Logging directory: /tmp/tmpr8tbwyfj/logs
2025-10-29 22:45:33 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:45:37 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:45:41 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:45:42 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:45:48 [INFO] [task_scheduler.cc:326]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
------------------------------------------------------------------------------------------------------------
0 | main | 1315467952128 | 1 | N/A | N/A | N/A | 0 | Y
------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0
[22:45:48] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpr8tbwyfj/database_workload.json, Tuning records at: /tmp/tmpr8tbwyfj/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
T.func_attr({"op_pattern": 4, "tir.noalias": True})
# with T.block("root"):
pad_temp = T.alloc_buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)))
for i0, i1, i2, i3 in T.grid(T.int64(512), T.int64(128), T.int64(68), T.int64(68)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(x[v_i0, v_i1, v_i2, v_i3])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
pad_temp[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3]
for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(512), T.int64(256), T.int64(66), T.int64(66), T.int64(128), T.int64(3), T.int64(3)):
with T.block("conv2d_nchw"):
v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight[v_ff, v_rc, v_ry, v_rx])
T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
with T.init():
conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight[v_ff, v_rc, v_ry, v_rx]
@R.function
def forward(x: R.Tensor((512, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((512, 256, 66, 66), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((512, 256, 66, 66), dtype="float32"))
R.output(gv)
return gv
Traceback (most recent call last):
File "/home/src/conv.py", line 153, in <module>
ex = tvm.compile(mod, target="cuda")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tvm/python/tvm/driver/build_module.py", line 104, in compile
return tvm.relax.build(
^^^^^^^^^^^^^^^^
File "/home/tvm/python/tvm/relax/vm_build.py", line 259, in build
return _vmlink(
^^^^^^^^
File "/home/tvm/python/tvm/relax/vm_build.py", line 154, in _vmlink
lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tvm/python/tvm/tir/build.py", line 173, in build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
File "/home/tvm/src/ir/transform.cc", line 551, in operator()
.set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
return this->operator()(std::move(mod), PassContext::Current());
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
mod = pass_func(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 545, in operator()
return pass_func(ffi::RValueRef<IRModule>(std::move(mod)), ctx);
^^^
File "tvm/ffi/cython/function.pxi", line 383, in tvm.ffi.core.tvm_ffi_callback
File "/home/tvm/python/tvm/tir/pipeline.py", line 122, in _pipeline
mod = tvm.ir.transform.Sequential(passes)(mod)
^^^^^^^
File "/home/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^
File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
File "/home/tvm/src/ir/transform.cc", line 551, in operator()
.set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 295, in tvm::transform::Pass::operator()(tvm::IRModule) const
return this->operator()(std::move(mod), PassContext::Current());
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 490, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
mod = pass(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 311, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
ret = node->operator()(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/ir/transform.cc", line 420, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
mod = pass_func(std::move(mod), pass_ctx);
^^^^^^^^^^^
File "/home/tvm/src/tir/analysis/verify_memory.cc", line 203, in operator()
LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
^^^^^
RuntimeError: Memory verification failed with the following errors:
Variable `x` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `conv2d_nchw` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `conv_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Did you forget to bind?
# from tvm.script import tir as T
@T.prim_func
def conv2d(x: T.Buffer((T.int64(512), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(512), T.int64(256), T.int64(66), T.int64(66)), "float32")):
T.func_attr({"op_pattern": 4, "target": T.target({"arch": "sm_80", "host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-conda-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": True})
pad_temp = T.allocate([303038464], "float32", "global")
pad_temp_1 = T.Buffer((T.int64(303038464),), data=pad_temp)
for i0, i1, i2, i3 in T.grid(512, 128, 68, 68):
cse_v1: T.int32 = i0 * 591872 + i1 * 4624 + i2 * 68 + i3
x_1 = T.Buffer((T.int64(303038464),), data=x.data)
pad_temp_1[cse_v1] = x_1[cse_v1]
for nn, ff, yy, xx, rc, ry, rx in T.grid(512, 256, 66, 66, 128, 3, 3):
cse_v2: T.int32 = nn * 1115136 + ff * 4356 + yy * 66 + xx
conv2d_nchw_1 = T.Buffer((T.int64(570949632),), data=conv2d_nchw.data)
if rc == 0 and ry == 0 and rx == 0:
conv2d_nchw_1[cse_v2] = T.float32(0.0)
conv_weight_1 = T.Buffer((T.int64(294912),), data=conv_weight.data)
conv2d_nchw_1[cse_v2] = conv2d_nchw_1[cse_v2] + pad_temp_1[nn * 591872 + rc * 4624 + yy * 68 + ry * 68 + xx + rx] * conv_weight_1[ff * 1152 + rc * 9 + ry * 3 + rx]
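If I read the dump above correctly, the task finished with 0 trials, so MetaScheduleApplyDatabase had nothing to apply: conv2d is left as plain nested loops with no threadIdx/blockIdx binding, and the CUDA build then fails memory verification ("Did you forget to bind?"). As a sanity check only, here is a sketch that sidesteps tuning rather than fixing it, assuming dlight's ApplyDefaultSchedule / gpu.Fallback is the right API for binding leftover unscheduled PrimFuncs:

# Workaround sketch (my assumption, not verified for performance): bind the
# unscheduled conv2d with dlight's generic GPU fallback so memory verification
# passes even when the tuning database is empty.
import tvm.dlight as dl

with target:
    # ApplyDefaultSchedule skips PrimFuncs that are already scheduled, so this
    # only touches functions the tuning database did not cover.
    mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)

ex = tvm.compile(mod, target="cuda")

The real question remains why the tuner produces 0 trials for this batch-512 shape in the first place.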
For input_shape = (128, 128, 68, 68), the search stopped after only 64 trials, far fewer than the 2000 I set, but no error was raised:
(tvm) [alen@node1 test]$ python test.py
2025-10-29 22:33:26 [INFO] Logging directory: /tmp/tmpu9i7vb6z/logs
2025-10-29 22:33:38 [INFO] LocalBuilder: max_workers = 128
2025-10-29 22:33:42 [INFO] LocalRunner: max_workers = 1
2025-10-29 22:33:46 [INFO] [task_scheduler.cc:165] Initializing Task #0: "main"
2025-10-29 22:33:47 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:34:08 [INFO] [task_scheduler.cc:199] Sending 64 sample(s) to builder
2025-10-29 22:34:47 [INFO] [task_scheduler.cc:201] Sending 64 sample(s) to runner
2025-10-29 22:42:31 [DEBUG] XGB iter 0: tr-p-rmse: 0.590935 tr-a-peak@32: 0.955485 tr-rmse: 0.415917 tr-rmse: 0.415917
2025-10-29 22:42:31 [DEBUG] XGB iter 25: tr-p-rmse: 0.054546 tr-a-peak@32: 1.000000 tr-rmse: 0.454526 tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB iter 50: tr-p-rmse: 0.054546 tr-a-peak@32: 1.000000 tr-rmse: 0.454526 tr-rmse: 0.454526
2025-10-29 22:42:31 [DEBUG] XGB stopped. Best iteration: [11] tr-p-rmse:0.05455 tr-a-peak@32:1.00000 tr-rmse:0.45453 tr-rmse:0.45453
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:243] [Updated] Task #0: "main"
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:326]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
-----------------------------------------------------------------------------------------------------------
0 | main | 328866988032 | 1 | 7501.8255 | 43838.2617 | 43838.2617 | 64 |
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3
2025-10-29 22:42:31 [INFO] [task_scheduler.cc:186] TaskScheduler picks Task #0: "main"
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:266] Task #0 has finished. Remaining task(s): 0
2025-10-29 22:42:44 [INFO] [task_scheduler.cc:326]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
-----------------------------------------------------------------------------------------------------------
0 | main | 328866988032 | 1 | 7501.8255 | 43838.2617 | 43838.2617 | 64 | Y
-----------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 43838.3
[22:42:44] /home/tvm/src/relax/transform/meta_schedule.cc:89: Warning: Creating JSONDatabase. Workload at: /tmp/tmpu9i7vb6z/database_workload.json, Tuning records at: /tmp/tmpu9i7vb6z/database_tuning_record.json
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def conv2d(x: T.Buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), "float32"), conv_weight: T.Buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), "float32")):
T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
conv2d_nchw_local = T.alloc_buffer((T.int64(128), T.int64(256), T.int64(66), T.int64(66)), scope="local")
pad_temp_shared = T.alloc_buffer((T.int64(128), T.int64(128), T.int64(68), T.int64(68)), scope="shared")
conv_weight_shared = T.alloc_buffer((T.int64(256), T.int64(128), T.int64(3), T.int64(3)), scope="shared")
for nn_0_ff_0_yy_0_xx_0_fused in T.thread_binding(T.int64(15488), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
for nn_1_ff_1_yy_1_xx_1_fused in T.thread_binding(T.int64(8), thread="vthread.x"):
for nn_2_ff_2_yy_2_xx_2_fused in T.thread_binding(T.int64(96), thread="threadIdx.x"):
for nn_3_init, ff_3_init, yy_3_init, xx_3_init, nn_4_init, ff_4_init, yy_4_init, xx_4_init in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
with T.block("conv2d_nchw_init"):
v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3_init + nn_4_init)
v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3_init + ff_4_init)
v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3_init + yy_4_init)
v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3_init + xx_4_init)
T.reads()
T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
for rc_0, ry_0, rx_0 in T.grid(T.int64(32), T.int64(1), T.int64(3)):
for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(32)):
for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
with T.block("pad_temp_shared"):
v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) // T.int64(192))
v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(192) // T.int64(48))
v2 = T.axis.spatial(T.int64(68), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(48) // T.int64(6))
v3 = T.axis.spatial(T.int64(68), rx_0 + nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(96) + ax0_ax1_ax2_ax3_fused_1) % T.int64(6))
T.reads(x[v0, v1, v2, v3])
T.writes(pad_temp_shared[v0, v1, v2, v3])
pad_temp_shared[v0, v1, v2, v3] = x[v0, v1, v2, v3]
for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):
for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(96), thread="threadIdx.x"):
for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)):
with T.block("conv.weight_shared"):
v0 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(12))
v1 = T.axis.spatial(T.int64(128), rc_0 * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(12) // T.int64(3))
v2 = T.axis.spatial(T.int64(3), (ax0_ax1_ax2_ax3_fused_0 * T.int64(192) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(3))
v3 = T.axis.spatial(T.int64(3), rx_0)
T.reads(conv_weight[v0, v1, v2, v3])
T.writes(conv_weight_shared[v0, v1, v2, v3])
conv_weight_shared[v0, v1, v2, v3] = conv_weight[v0, v1, v2, v3]
for rc_1, ry_1, rx_1, nn_3, ff_3, yy_3, xx_3, rc_2, ry_2, rx_2, nn_4, ff_4, yy_4, xx_4 in T.grid(T.int64(4), T.int64(3), T.int64(1), T.int64(2), T.int64(1), T.int64(6), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
with T.block("conv2d_nchw_update"):
v_nn = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + nn_3 + nn_4)
v_ff = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ff_3 + ff_4)
v_yy = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + yy_3 + yy_4)
v_xx = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + xx_3 + xx_4)
v_rc = T.axis.reduce(T.int64(128), rc_0 * T.int64(4) + rc_1 + rc_2)
v_ry = T.axis.reduce(T.int64(3), ry_0 * T.int64(3) + ry_1 + ry_2)
v_rx = T.axis.reduce(T.int64(3), rx_0 + rx_1 + rx_2)
T.reads(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx], pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_weight_shared[v_ff, v_rc, v_ry, v_rx])
T.writes(conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx])
T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_local[v_nn, v_ff, v_yy, v_xx] + pad_temp_shared[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_weight_shared[v_ff, v_rc, v_ry, v_rx]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1), T.int64(6), T.int64(1)):
with T.block("conv2d_nchw_local"):
v0 = T.axis.spatial(T.int64(128), nn_0_ff_0_yy_0_xx_0_fused // T.int64(1936) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused // T.int64(4) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused // T.int64(24) * T.int64(2) + ax0)
v1 = T.axis.spatial(T.int64(256), nn_0_ff_0_yy_0_xx_0_fused % T.int64(1936) // T.int64(121) * T.int64(16) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(4) // T.int64(2) * T.int64(8) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(24) // T.int64(3) + ax1)
v2 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(121) // T.int64(11) * T.int64(6) + ax2)
v3 = T.axis.spatial(T.int64(66), nn_0_ff_0_yy_0_xx_0_fused % T.int64(11) * T.int64(6) + nn_1_ff_1_yy_1_xx_1_fused % T.int64(2) * T.int64(3) + nn_2_ff_2_yy_2_xx_2_fused % T.int64(3) + ax3)
T.reads(conv2d_nchw_local[v0, v1, v2, v3])
T.writes(conv2d_nchw[v0, v1, v2, v3])
conv2d_nchw[v0, v1, v2, v3] = conv2d_nchw_local[v0, v1, v2, v3]
@R.function
def forward(x: R.Tensor((128, 128, 68, 68), dtype="float32"), conv_weight: R.Tensor((256, 128, 3, 3), dtype="float32")) -> R.Tensor((128, 256, 66, 66), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
gv = R.call_tir(cls.conv2d, (x, conv_weight), out_sinfo=R.Tensor((128, 256, 66, 66), dtype="float32"))
R.output(gv)
return gv
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
28.1586 28.1586 28.1586 28.1586 0.0000
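As far as I know, 64 is MetaSchedule's default number of candidates per search iteration, so it looks like the scheduler declared the task done after a single round instead of spending the 2000-trial budget. Below is a sketch of how I would drive the TIR tuner directly with explicit trial settings and count the measured records; the keyword names (work_dir, max_trials_global, num_trials_per_iter) are my assumption about the current MetaSchedule API, and the snippet is meant to run inside the `with target, tempfile.TemporaryDirectory() as tmp_dir:` block from the script above.

from tvm import meta_schedule as ms

# Lower with the zero pipeline first, then pull out the conv2d PrimFunc.
zero_mod = relax.get_pipeline("zero")(mod)

# Tune the conv2d PrimFunc with an explicit global budget and per-iteration
# batch size (64 is the default batch size, to my knowledge).
db = ms.tir_integration.tune_tir(
    mod=zero_mod["conv2d"],
    target=target,
    work_dir=tmp_dir,
    max_trials_global=2000,
    num_trials_per_iter=64,
)

# How many candidates were actually measured and kept?
print("tuning records:", len(db.get_all_tuning_records()))

Either way, the question is why the search stops after one iteration for this shape instead of continuing toward the configured trial count.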