Hello,

Any ideas as to what is wrong? The example script should surely work with an RTX 4090 target.

As an alternative to the tvm.target.cuda() helper (which fails for me with a ValueError; see the comments at lines 58-60 of the script below), I specified my RTX 4090 target in the pattern of the example's original RTX 3090 Ti target, using the Target constructor. That results in the following error:
RuntimeError: Memory verification failed with the following errors:
...
Did you forget to bind?
Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
... several more variables listed for the same problem
The failure occurs with the target specified on line 56 of the script:

target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda")  # Change to your target device

So I'm unable to execute the end-to-end example when specifying an RTX 4090 target. Should I file a bug?
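For reference, here is a minimal summary of the two target specifications I tried, both taken from the script below, with the error each one triggers on my machine (the comments paraphrase the messages quoted above and in the script):

import tvm

# Attempt 1: the tvm.target.cuda() helper
# -> ValueError: missing attribute `max_threads_per_block` in the target
# target = tvm.target.cuda(model="4090", arch="sm_89")

# Attempt 2: the Target constructor, patterned after the example's
# original "nvidia/geforce-rtx-3090-ti" target
# -> RuntimeError from relax.build: "Memory verification failed ... Did you forget to bind?"
target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda")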
I built TVM from the latest source; the build configuration reported by tvm.support.libinfo() is shown below. A listing of the script and the corresponding output follow.
USE_NVTX: OFF
USE_GTEST: AUTO
SUMMARIZE: OFF
TVM_DEBUG_WITH_ABI_CHANGE: OFF
USE_IOS_RPC: OFF
USE_MSC: OFF
CUDA_VERSION: 12.8
USE_LIBBACKTRACE: AUTO
DLPACK_PATH: 3rdparty/dlpack/include
USE_TENSORRT_CODEGEN: OFF
USE_OPENCL_EXTN_QCOM: NOT-FOUND
USE_THRUST: OFF
BUILD_DUMMY_LIBTVM: OFF
USE_CUDNN: ON
USE_TENSORRT_RUNTIME: OFF
USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
USE_CCACHE: AUTO
USE_ARM_COMPUTE_LIB: OFF
USE_CPP_RTVM: OFF
USE_OPENCL_GTEST: /path/to/opencl/gtest
TVM_LOG_BEFORE_THROW: OFF
USE_MKL: OFF
MLIR_VERSION: NOT-FOUND
USE_CLML: OFF
USE_STACKVM_RUNTIME: OFF
ROCM_PATH: /opt/rocm
USE_DNNL: OFF
USE_MSCCL: OFF
USE_NNAPI_RUNTIME: OFF
USE_MLIR: OFF
USE_RCCL: OFF
USE_LLVM: llvm-config --ignore-libllvm --link-static
USE_THREADS: ON
USE_MSVC_MT: OFF
BACKTRACE_ON_SEGFAULT: OFF
USE_ROCBLAS: OFF
USE_NCCL: OFF
GIT_COMMIT_HASH: 0d42dc411d8729a97a7d6cc89266da2b50045897
USE_VULKAN: OFF
USE_RUST_EXT: OFF
USE_CUTLASS: OFF
USE_CPP_RPC: OFF
USE_HEXAGON: OFF
USE_CUSTOM_LOGGING: OFF
USE_UMA: OFF
USE_FALLBACK_STL_MAP: OFF
USE_SORT: ON
USE_RTTI: ON
GIT_COMMIT_TIME: 2025-02-18 10:30:10 -0500
USE_HIPBLAS: OFF
USE_HEXAGON_SDK: /path/to/sdk
USE_BLAS: none
USE_LIBTORCH: OFF
USE_RANDOM: ON
USE_CUDA: ON
USE_COREML: OFF
USE_AMX: OFF
BUILD_STATIC_RUNTIME: OFF
USE_KHRONOS_SPIRV: OFF
USE_CLML_GRAPH_EXECUTOR: OFF
USE_TFLITE: OFF
USE_HEXAGON_GTEST: /path/to/hexagon/gtest
PICOJSON_PATH: 3rdparty/picojson
USE_OPENCL_ENABLE_HOST_PTR: OFF
INSTALL_DEV: OFF
USE_NNPACK: OFF
LLVM_VERSION: 19.1.7
USE_MRVL: OFF
USE_OPENCL: OFF
COMPILER_RT_PATH: 3rdparty/compiler-rt
USE_NNAPI_CODEGEN: OFF
RANG_PATH: 3rdparty/rang/include
USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
USE_OPENMP: none
USE_BNNS: OFF
USE_FLASHINFER: OFF
USE_CUBLAS: ON
USE_METAL: OFF
USE_HEXAGON_EXTERNAL_LIBS: OFF
USE_ALTERNATIVE_LINKER: AUTO
USE_BYODT_POSIT: OFF
USE_NVSHMEM: OFF
USE_HEXAGON_RPC: OFF
DMLC_PATH: 3rdparty/dmlc-core/include
INDEX_DEFAULT_I64: ON
USE_RPC: ON
USE_TENSORFLOW_PATH: none
TVM_CLML_VERSION:
USE_MIOPEN: OFF
USE_ROCM: OFF
USE_PAPI: OFF
USE_CURAND: OFF
TVM_CXX_COMPILER_PATH: /usr/bin/c++
HIDE_PRIVATE_SYMBOLS: ON
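(For completeness, the listing above was dumped with roughly the following snippet; I assume only the key/value pairs matter, not the exact formatting.)

from tvm.support import libinfo

# Print each build-configuration key/value reported by the installed libtvm
for key, value in libinfo().items():
    print(f"{key}: {value}")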
Here is the script:
1 # Licensed to the Apache Software Foundation (ASF) under one
2 # or more contributor license agreements. See the NOTICE file
3 # distributed with this work for additional information
4 # regarding copyright ownership. The ASF licenses this file
5 # to you under the Apache License, Version 2.0 (the
6 # "License"); you may not use this file except in compliance
7 # with the License. You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing,
12 # software distributed under the License is distributed on an
13 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 # KIND, either express or implied. See the License for the
15 # specific language governing permissions and limitations
16 # under the License.
17
18 import os
19 import numpy as np
20 import torch
21 from torch.export import export
22 from torchvision.models.resnet import ResNet18_Weights, resnet18
23
24 import tvm
25 from tvm import relax
26 from tvm.relax.frontend.torch import from_exported_program
27
28 print("\n##### TVM and CUDA Checks #####")
29 print(f"TVM File: {tvm.__file__}")
30 print(f"TVM base lib: {tvm._ffi.base._LIB}")
31 print(f"Cuda exists?: {tvm.cuda().exist}")
32 print("################################\n")
33
34 torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
35
36 # Give an example argument to torch.export
37 example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
38
39 # Convert the model to IRModule
40 with torch.no_grad():
41     exported_program = export(torch_model, example_args)
42     mod = from_exported_program(exported_program, keep_params_as_input=True)
43
44 mod, params = relax.frontend.detach_params(mod)
45 # mod.show()
46
47 TOTAL_TRIALS = 1 # Change to 20000 for better performance if needed
48
49 # target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") # Change to your target device
50
51 # The following target specification generates a RuntimeError
52 # RuntimeError: Memory verification failed with the following errors:
53 # ...
54 # Did you forget to bind?
55 # Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
56 target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda") # Change to your target device
57
58 # The following target specification generates a ValueError
59 # ValueError: Check failed: (max_threads_per_block.defined()) is false: missing attribute `max_threads_per_block` in the target
60 # target = tvm.target.cuda(model="4090", arch="sm_89");
61
62 print("\n##### Target Info #####")
63 print(f"Device Type: {target.get_target_device_type()}")
64 print(f"Model: {target.model}")
65 print(f"Arch: {target.arch}")
66 print(f"Max Threads: {target.max_num_threads}")
67 # The following two lines generate the error shown below
68 # print(f"Max Block Size x: {target.max_block_size_x}")
69 # print(f"Max Block Size y: {target.max_block_size_y}")
70 # Generates error:
71 # return int(self.attrs["max_block_size_x"])
72 # ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
73 # InternalError: Check failed: (it != n->end()) is false: cannot find the corresponding key in the Map
74 #
75 print(f"Thread Warp Size: {target.thread_warp_size}")
76 print("#########################\n")
77
78 work_dir = "tuning_logs"
79
80 # Skip running in CI environment
81 IS_IN_CI = os.getenv("CI", "") == "true"
82 if not IS_IN_CI:
83     mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
84
85     # Only show the main function
86     mod["main"].show()
87
88 if not IS_IN_CI:
89     ex = relax.build(mod, target="cuda")
90     dev = tvm.device("cuda", 0)
91     vm = relax.VirtualMachine(ex, dev)
92     # Need to allocate data and params on GPU device
93     gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
94     gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
95     gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
96
97     print(gpu_out.shape)
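As an aside on the commented-out reads of max_block_size_x / max_block_size_y at script lines 68-69: a guarded lookup avoids the InternalError when a target does not define those keys. This is only a sketch and assumes target.attrs supports dict-style membership tests:

# Probe optional target attributes instead of indexing them unconditionally
for key in ("max_block_size_x", "max_block_size_y"):
    if key in target.attrs:
        print(f"{key}: {int(target.attrs[key])}")
    else:
        print(f"{key}: not defined for this target")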
Here is the output, excerpted to fit under the maximum posting size:
1
2 ##### TVM and CUDA Checks #####
3 TVM File: /home/rbohl/work/ai_tools/tvm/python/tvm/__init__.py
4 TVM base lib: <CDLL '/home/rbohl/work/ai_tools/tvm/build/libtvm.so', handle 63be0f4cda10 at 0x74401d268310>
5 Cuda exists?: True
6 ################################
7
8
9 ##### Target Info #####
10 Device Type: 2
11 Model: unknown
12 Arch: sm_89
13 Max Threads: 1024
14 Thread Warp Size: 32
15 #########################
16
17 2025-02-27 10:05:28 [INFO] Logging directory: tuning_logs/logs
18 2025-02-27 10:05:35 [INFO] LocalBuilder: max_workers = 8
19 2025-02-27 10:05:36 [INFO] LocalRunner: max_workers = 1
20 2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #0: "fused_matmul_add13"
21 2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #1: "reshape"
22 2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #2: "fused_conv2d_subtract_divide_expand_dims_multiply_expand_dims_add1_relu"
23 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #3: "adaptive_avg_pool2d"
24 [10:05:41] /home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/apply_custom_rule.cc:56: Warning: Unknown schedule rule "meta_schedule.adaptive_pool_avg" for target keys "["cuda", "gpu"]". Checked PackedFuncs:
25 meta_schedule.cuda.meta_schedule.adaptive_pool_avg
26 meta_schedule.gpu.meta_schedule.adaptive_pool_avg
27 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #4: "fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4"
28 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #5: "fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4"
29 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #6: "max_pool2d"
30 [10:05:41] /home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/apply_custom_rule.cc:56: Warning: Unknown schedule rule "meta_schedule.pool_max" for target keys "["cuda", "gpu"]". Checked PackedFuncs:
31 meta_schedule.cuda.meta_schedule.pool_max
32 meta_schedule.gpu.meta_schedule.pool_max
33 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #7: "fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2"
34 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #8: "fused_conv2d10_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11"
35 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #9: "fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3"
36 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #10: "transpose"
37 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #11: "fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1"
38 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #12: "fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1"
39 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #13: "fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8"
40 2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #14: "fused_conv2d2_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2"
41 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #15: "fused_conv2d4_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5"
42 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #16: "fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2"
43 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #17: "fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4"
44 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #18: "fused_conv2d5_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3"
45 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #19: "fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3"
46 2025-02-27 10:05:42 [INFO] [task_scheduler.cc:320]
**... output tables from multiple trial runs ...**
664 [10:05:48] /home/rbohl/work/ai_tools/tvm/src/relax/transform/meta_schedule.cc:119: Warning: Creating JSONDatabase. Workload at: tuning_logs/database_workload.json, Tuning records at: tuning_logs/database_tuning_record.json
665 # from tvm.script import relax as R
666
667 @R.function
668 def main(x: R.Tensor((1, 3, 224, 224), dtype="float32"), p_conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), p_bn1_weight: R.Tensor((64,), dtype="float32"), p_bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___0___bn1_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___0___bn2_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___bn2_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___1___bn1_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___1___bn2_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___bn2_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer2___0___conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer2___0___bn1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___bn1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___0___bn2_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___bn2_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), p_getattr_l__self___layer2___0___downsample_1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___downsample_1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___1___bn1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___bn1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___1___bn2_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___bn2_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer3___0___conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer3___0___bn1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___bn1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___0___bn2_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___bn2_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), p_getattr_l__self___layer3___0___downsample_1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___downsample_1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___1___bn1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___bn1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___conv2_weight: 
R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___1___bn2_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___bn2_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer4___0___conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer4___0___bn1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___bn1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___0___bn2_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___bn2_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), p_getattr_l__self___layer4___0___downsample_1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___downsample_1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___1___bn1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___bn1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___1___bn2_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___bn2_bias: R.Tensor((512,), dtype="float32"), p_fc_weight: R.Tensor((1000, 512), dtype="float32"), p_fc_bias: R.Tensor((1000,), dtype="float32")) -> R.Tuple(R.Tensor((1, 1000), dtype="float32")):
669 R.func_attr({"num_input": 1})
670 with R.dataflow():
671 lv = R.call_tir(fused_conv2d_subtract_divide_expand_dims_multiply_expand_dims_add1_relu, (x, p_conv1_weight, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], p_bn1_weight, p_bn1_bias), out_sinfo=R.Tensor((1, 64, 112, 112), dtype="float32"))
672 lv4 = R.call_tir(max_pool2d, (lv,), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
673 lv1 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1, (lv4, p_getattr_l__self___layer1___0___conv1_weight, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], p_getattr_l__self___layer1___0___bn1_weight, p_getattr_l__self___layer1___0___bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
674 lv2 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1, (lv1, p_getattr_l__self___layer1___0___conv2_weight, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], p_getattr_l__self___layer1___0___bn2_weight, p_getattr_l__self___layer1___0___bn2_bias, lv4), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
675 lv3 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1, (lv2, p_getattr_l__self___layer1___1___conv1_weight, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], p_getattr_l__self___layer1___1___bn1_weight, p_getattr_l__self___layer1___1___bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
676 lv4_1 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1, (lv3, p_getattr_l__self___layer1___1___conv2_weight, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], p_getattr_l__self___layer1___1___bn2_weight, p_getattr_l__self___layer1___1___bn2_bias, lv2), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
677 lv5 = R.call_tir(fused_conv2d2_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2, (lv4_1, p_getattr_l__self___layer2___0___conv1_weight, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], p_getattr_l__self___layer2___0___bn1_weight, p_getattr_l__self___layer2___0___bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
678 lv6 = R.call_tir(fused_conv2d4_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5, (lv4_1, p_getattr_l__self___layer2___0___downsample_0_weight, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], p_getattr_l__self___layer2___0___downsample_1_weight, p_getattr_l__self___layer2___0___downsample_1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
679 lv7 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2, (lv5, p_getattr_l__self___layer2___0___conv2_weight, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], p_getattr_l__self___layer2___0___bn2_weight, p_getattr_l__self___layer2___0___bn2_bias, lv6), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
680 lv8 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2, (lv7, p_getattr_l__self___layer2___1___conv1_weight, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], p_getattr_l__self___layer2___1___bn1_weight, p_getattr_l__self___layer2___1___bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
681 lv9 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2, (lv8, p_getattr_l__self___layer2___1___conv2_weight, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], p_getattr_l__self___layer2___1___bn2_weight, p_getattr_l__self___layer2___1___bn2_bias, lv7), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
682 lv10 = R.call_tir(fused_conv2d5_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3, (lv9, p_getattr_l__self___layer3___0___conv1_weight, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], p_getattr_l__self___layer3___0___bn1_weight, p_getattr_l__self___layer3___0___bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
683 lv11 = R.call_tir(fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8, (lv9, p_getattr_l__self___layer3___0___downsample_0_weight, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], p_getattr_l__self___layer3___0___downsample_1_weight, p_getattr_l__self___layer3___0___downsample_1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
684 lv12 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3, (lv10, p_getattr_l__self___layer3___0___conv2_weight, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], p_getattr_l__self___layer3___0___bn2_weight, p_getattr_l__self___layer3___0___bn2_bias, lv11), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
685 lv13 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3, (lv12, p_getattr_l__self___layer3___1___conv1_weight, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], p_getattr_l__self___layer3___1___bn1_weight, p_getattr_l__self___layer3___1___bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
686 lv14 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3, (lv13, p_getattr_l__self___layer3___1___conv2_weight, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], p_getattr_l__self___layer3___1___bn2_weight, p_getattr_l__self___layer3___1___bn2_bias, lv12), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
687 lv15 = R.call_tir(fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4, (lv14, p_getattr_l__self___layer4___0___conv1_weight, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], p_getattr_l__self___layer4___0___bn1_weight, p_getattr_l__self___layer4___0___bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
688 lv16 = R.call_tir(fused_conv2d10_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11, (lv14, p_getattr_l__self___layer4___0___downsample_0_weight, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], p_getattr_l__self___layer4___0___downsample_1_weight, p_getattr_l__self___layer4___0___downsample_1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
689 lv17 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4, (lv15, p_getattr_l__self___layer4___0___conv2_weight, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], p_getattr_l__self___layer4___0___bn2_weight, p_getattr_l__self___layer4___0___bn2_bias, lv16), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
690 lv18 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4, (lv17, p_getattr_l__self___layer4___1___conv1_weight, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], p_getattr_l__self___layer4___1___bn1_weight, p_getattr_l__self___layer4___1___bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
691 lv19 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4, (lv18, p_getattr_l__self___layer4___1___conv2_weight, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], p_getattr_l__self___layer4___1___bn2_weight, p_getattr_l__self___layer4___1___bn2_bias, lv17), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
692 lv86 = R.call_tir(adaptive_avg_pool2d, (lv19,), out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"))
693 lv87 = R.call_tir(reshape, (lv86,), out_sinfo=R.Tensor((1, 512), dtype="float32"))
694 lv88 = R.call_tir(transpose, (p_fc_weight,), out_sinfo=R.Tensor((512, 1000), dtype="float32"))
695 lv20 = R.call_tir(fused_matmul_add13, (lv87, lv88, p_fc_bias), out_sinfo=R.Tensor((1, 1000), dtype="float32"))
696 gv: R.Tuple(R.Tensor((1, 1000), dtype="float32")) = (lv20,)
697 R.output(gv)
698 return gv
699
700 # Metadata omitted. Use show_meta=True in script() method to show it.
701
702 Traceback (most recent call last):
703 File "/home/rbohl/work/ai_tools/tvm/e2e_debug.py", line 85, in <module>
704 ex = relax.build(mod, target="cuda")
705 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
706 File "/home/rbohl/work/ai_tools/tvm/python/tvm/relax/vm_build.py", line 352, in build
707 return _vmlink(
708 ^^^^^^^^
709 File "/home/rbohl/work/ai_tools/tvm/python/tvm/relax/vm_build.py", line 252, in _vmlink
710 lib = tvm.build(tir_mod, target=target)
711 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
712 File "/home/rbohl/work/ai_tools/tvm/python/tvm/driver/build_module.py", line 145, in build
713 rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
714 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
715 File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
716 File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
717 File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
718 File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
719 File "/home/rbohl/work/ai_tools/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
720 raise py_err
721 File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 445, in operator()
722 return TIRToRuntime(inputs_arg, host_target);
723 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
724 File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 406, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
725 auto pair = SplitMixedModule(ir_module, target, target_host);
726 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
727 File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 333, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
728 mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
729 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
730 File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 287, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
731 mod = seq(std::move(mod));
732 ^^^^^^^^^^^
733 File "/home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc", line 203, in operator()
734 LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
735 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
736 tvm._ffi.base.TVMError: Traceback (most recent call last):
737 4: operator()
738 at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:445
739 3: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
740 at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:406
741 2: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
742 at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:333
743 1: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
744 at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:287
745 0: operator()
746 at /home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc:203
747 Did you forget to bind?
748 Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
749 Variable `lv4` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
750 Variable `p_getattr_l__self___layer1___0___bn2_bias` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
751 Variable `p_getattr_l__self___layer1___0___bn2_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
752 Variable `param_1` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
753 Variable `param_0` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
754 Variable `p_getattr_l__self___layer1___0___conv2_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
755 Variable `lv8` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
756 File "/home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc", line 203
757 RuntimeError: Memory verification failed with the following errors:
758 # from tvm.script import tir as T
759
760 @T.prim_func
761 def fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1(lv8: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), p_getattr_l__self___layer1___0___conv2_weight: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), param_0: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"), param_1: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"), p_getattr_l__self___layer1___0___bn2_weight: T.Buffer((T.int64(64),), "float32"), p_getattr_l__self___layer1___0___bn2_bias: T.Buffer((T.int64(64),), "float32"), lv4: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")):
762 T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-conda-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
763 pad_temp = T.allocate([215296], "float32", "global")
764 conv2d_nchw = T.allocate([200704], "float32", "global")
765 expand_dims = T.allocate([64], "float32", "global")
766 pad_temp_1 = T.Buffer((T.int64(215296),), data=pad_temp)
767 for i1, i2, i3 in T.grid(64, 58, 58):
768 lv8_1 = T.Buffer((T.int64(200704),), data=lv8.data)
769 pad_temp_1[i1 * 3364 + i2 * 58 + i3] = T.if_then_else(1 <= i2 and i2 < 57 and 1 <= i3 and i3 < 57, lv8_1[i1 * 3136 + i2 * 56 + i3 - 57], T.float32(0.0))
770 conv2d_nchw_1 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
771 for ff, yy, xx, rc, ry, rx in T.grid(64, 56, 56, 64, 3, 3):
772 cse_var_1: T.int32 = ff * 3136 + yy * 56 + xx
773 if rc == 0 and ry == 0 and rx == 0:
774 conv2d_nchw_1[cse_var_1] = T.float32(0.0)
775 p_getattr_l__self___layer1___0___conv2_weight_1 = T.Buffer((T.int64(36864),), data=p_getattr_l__self___layer1___0___conv2_weight.data)
776 conv2d_nchw_1[cse_var_1] = conv2d_nchw_1[cse_var_1] + pad_temp_1[rc * 3364 + yy * 58 + ry * 58 + xx + rx] * p_getattr_l__self___layer1___0___conv2_weight_1[ff * 576 + rc * 9 + ry * 3 + rx]
777 conv2d_nchw_2 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
778 for ax1, ax2, ax3 in T.grid(64, 56, 56):
779 cse_var_2: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
780 param_0_1 = T.Buffer((T.int64(64),), data=param_0.data)
781 conv2d_nchw_2[cse_var_2] = conv2d_nchw_1[cse_var_2] - param_0_1[ax1]
782 conv2d_nchw_3 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
783 for ax1, ax2, ax3 in T.grid(64, 56, 56):
784 cse_var_3: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
785 param_1_1 = T.Buffer((T.int64(64),), data=param_1.data)
786 conv2d_nchw_3[cse_var_3] = conv2d_nchw_2[cse_var_3] / param_1_1[ax1]
787 expand_dims_1 = T.Buffer((T.int64(64),), data=expand_dims)
788 for i1 in range(64):
789 p_getattr_l__self___layer1___0___bn2_weight_1 = T.Buffer((T.int64(64),), data=p_getattr_l__self___layer1___0___bn2_weight.data)
790 expand_dims_1[i1] = p_getattr_l__self___layer1___0___bn2_weight_1[i1]
791 conv2d_nchw_4 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
792 for ax1, ax2, ax3 in T.grid(64, 56, 56):
793 cse_var_4: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
794 conv2d_nchw_4[cse_var_4] = conv2d_nchw_3[cse_var_4] * expand_dims_1[ax1]
795 expand_dims_2 = T.Buffer((T.int64(64),), data=expand_dims)
796 for i1 in range(64):
797 p_getattr_l__self___layer1___0___bn2_bias_1 = T.Buffer((T.int64(64),), data=p_getattr_l__self___layer1___0___bn2_bias.data)
798 expand_dims_2[i1] = p_getattr_l__self___layer1___0___bn2_bias_1[i1]
799 conv2d_nchw_5 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
800 for ax1, ax2, ax3 in T.grid(64, 56, 56):
801 cse_var_5: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
802 conv2d_nchw_5[cse_var_5] = conv2d_nchw_4[cse_var_5] + expand_dims_2[ax1]
803 conv2d_nchw_6 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
804 for ax1, ax2, ax3 in T.grid(64, 56, 56):
805 cse_var_6: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
806 lv4_1 = T.Buffer((T.int64(200704),), data=lv4.data)
807 conv2d_nchw_6[cse_var_6] = conv2d_nchw_5[cse_var_6] + lv4_1[cse_var_6]
808 for i1, i2, i3 in T.grid(64, 56, 56):
809 cse_var_7: T.int32 = i1 * 3136 + i2 * 56 + i3
810 compute_intermediate_1 = T.Buffer((T.int64(200704),), data=compute_intermediate.data)
811 compute_intermediate_1[cse_var_7] = T.max(conv2d_nchw_6[cse_var_7], T.float32(0.0))