ValueError regarding undefined max_threads_per_block

Hello,

Using TVM built from source in a micromamba environment on Ubuntu Linux 24.04, I’m getting the following error when using an RTX 4090 GPU:

ValueError: Check failed: (max_threads_per_block.defined()) is false: missing attribute `max_threads_per_block` in the target

at the following expression:

mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)

The script I’m running is a slightly modified version of the e2e_opt_model.py example at https://tvm.apache.org/docs/how_to/tutorials/e2e_opt_model.html. Note that the debug print of the Target’s max_num_threads attribute on line 58 of the script succeeds and prints 1024.
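
As a sanity check, dumping the target's full attribute map should confirm whether max_threads_per_block is simply absent from the Target object (a quick sketch; I'm relying on Target exposing its attribute map via .attrs):

    import tvm

    # Sketch: list every attribute this Target actually carries, to see
    # whether max_threads_per_block is among the keys.
    target = tvm.target.cuda(model="4090", arch="sm_89")
    for key, value in target.attrs.items():
        print(key, value)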

Any guidance on what I’m doing wrong is greatly appreciated.
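
One workaround I may try is to build the Target from a config dict and supply max_threads_per_block by hand. This is an untested sketch; the keys and values are my assumptions for an RTX 4090 (Ada, sm_89), based on the attributes the tag-based targets appear to define:

    import tvm

    # Untested sketch: construct the CUDA target from an explicit config
    # dict so that max_threads_per_block is definitely present.
    # The values below are assumed for an RTX 4090 (sm_89).
    target = tvm.target.Target(
        {
            "kind": "cuda",
            "arch": "sm_89",
            "max_threads_per_block": 1024,
            "max_num_threads": 1024,
            "thread_warp_size": 32,
        },
        host="llvm",
    )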

Here is the script:

     1	# Licensed to the Apache Software Foundation (ASF) under one
     2	# or more contributor license agreements.  See the NOTICE file
     3	# distributed with this work for additional information
     4	# regarding copyright ownership.  The ASF licenses this file
     5	# to you under the Apache License, Version 2.0 (the
     6	# "License"); you may not use this file except in compliance
     7	# with the License.  You may obtain a copy of the License at
     8	#
     9	#   http://www.apache.org/licenses/LICENSE-2.0
    10	#
    11	# Unless required by applicable law or agreed to in writing,
    12	# software distributed under the License is distributed on an
    13	# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    14	# KIND, either express or implied.  See the License for the
    15	# specific language governing permissions and limitations
    16	# under the License.
    17	
    18	import os
    19	import numpy as np
    20	import torch
    21	from torch.export import export
    22	from torchvision.models.resnet import ResNet18_Weights, resnet18
    23	
    24	import tvm
    25	from tvm import relax
    26	from tvm.relax.frontend.torch import from_exported_program
    27	
    28	print("\n##### TVM and CUDA Checks #####")
    29	print(f"TVM File: {tvm.__file__}")
    30	print(f"TVM base lib: {tvm._ffi.base._LIB}")
    31	print(f"Cuda exists?: {tvm.cuda().exist}")
    32	print("################################\n")
    33	
    34	torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    35	
    36	# Give an example argument to torch.export
    37	example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
    38	
    39	# Convert the model to IRModule
    40	with torch.no_grad():
    41	    exported_program = export(torch_model, example_args)
    42	    mod = from_exported_program(exported_program, keep_params_as_input=True)
    43	
    44	mod, params = relax.frontend.detach_params(mod)
    45	# mod.show()
    46	
    47	TOTAL_TRIALS = 1  # Change to 20000 for better performance if needed
    48	
    49	# target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
    50	# target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda")  # Change to your target device
    51	
    52	target = tvm.target.cuda(model="4090", arch="sm_89")
    53	
    54	print("\n##### Target Info #####")
    55	print(f"Device Type: {target.get_target_device_type()}")
    56	print(f"Model: {target.model}")
    57	print(f"Arch: {target.arch}")
    58	print(f"Max Threads: {target.max_num_threads}")
    59	# The following two lines generate the error shown below
    60	# print(f"Max Block Size x: {target.max_block_size_x}")
    61	# print(f"Max Block Size y: {target.max_block_size_y}")
    62	#   Generates error:
    63	#    return int(self.attrs["max_block_size_x"])
    64	#               ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
    65	#    InternalError: Check failed: (it != n->end()) is false: cannot find the corresponding key in the Map
    66	#
    67	print(f"Thread Warp Size: {target.thread_warp_size}")
    68	print("#########################\n")
    69	
    70	work_dir = "tuning_logs"
    71	
    72	# Skip running in CI environment
    73	IS_IN_CI = os.getenv("CI", "") == "true"
    74	if not IS_IN_CI:
    75	    mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
    76	
    77	    # Only show the main function
    78	    mod["main"].show()
    79	
    80	if not IS_IN_CI:
    81	    ex = relax.build(mod, target="cuda")
    82	    dev = tvm.device("cuda", 0)
    83	    vm = relax.VirtualMachine(ex, dev)
    84	    # Need to allocate data and params on GPU device
    85	    gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    86	    gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
    87	    gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
    88	
    89	    print(gpu_out.shape)

Here is the script output:

     2	##### TVM and CUDA Checks #####
     3	TVM File: /home/rbohl/work/ai_tools/tvm/python/tvm/__init__.py
     4	TVM base lib: <CDLL '/home/rbohl/work/ai_tools/tvm/build/libtvm.so', handle 5eaf87031940 at 0x7d97b4382610>
     5	Cuda exists?: True
     6	################################
     7	
     8	
     9	##### Target Info #####
    10	Device Type: 2
    11	Model: 4090
    12	Arch: sm_89
    13	Max Threads: 1024
    14	Thread Warp Size: 32
    15	#########################
    16	
    17	2025-02-19 10:50:21 [INFO] Logging directory: tuning_logs/logs
    18	Traceback (most recent call last):
    19	  File "/home/rbohl/work/ai_tools/tvm/e2e_debug.py", line 75, in <module>
    20	    mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
    21	          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    22	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/ir/transform.py", line 238, in __call__
    23	    return _ffi_transform_api.RunPass(self, mod)
    24	           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    25	  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
    26	  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
    27	  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
    28	  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
    29	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
    30	    raise py_err
    31	  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
    32	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/relax/pipeline.py", line 172, in _pipeline
    33	    mod = tvm.transform.Sequential(
    34	          ^^^^^^^^^^^^^^^^^^^^^^^^^
    35	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/ir/transform.py", line 238, in __call__
    36	    return _ffi_transform_api.RunPass(self, mod)
    37	           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    38	  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
    39	  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
    40	  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
    41	  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
    42	  File "/home/rbohl/work/ai_tools/tvm/src/relax/transform/meta_schedule.cc", line 187, in operator()
    43	    .TuneIRMod(m, ctx);
    44	              ^^^^^^^^^^
    45	  File "/home/rbohl/work/ai_tools/tvm/src/relax/transform/meta_schedule.cc", line 60, in tvm::relax::transform::MetaScheduleTuner::TuneIRMod(tvm::IRModule, tvm::transform::PassContext)
    46	    knob->Apply(mod, "0");
    47	                  ^^^^^^^^^
    48	  File "/home/rbohl/work/ai_tools/tvm/include/tvm/relax/tuning_api.h", line 149, in tvm::relax::KnobNode::Apply(tvm::IRModule, tvm::runtime::String)
    49	    return choices[decision]->ApplyTransformFunc(mod);
    50	                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    51	  File "/home/rbohl/work/ai_tools/tvm/include/tvm/relax/tuning_api.h", line 100, in tvm::relax::ChoiceNode::ApplyTransformFunc(tvm::IRModule)
    52	    return CallPackedWithArgsInArray(GetTransformFunc(), args);
    53	                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    54	  File "/home/rbohl/work/ai_tools/tvm/include/tvm/relax/tuning_api.h", line 48, in tvm::relax::CallPackedWithArgsInArray(tvm::runtime::PackedFunc, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&)
    55	    f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv);
    56	                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    57	  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
    58	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/meta_schedule/relax_integration.py", line 354, in _tune_relax
    59	    tune_relax(
    60	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/meta_schedule/relax_integration.py", line 248, in tune_relax
    61	    tasks, task_weights = extracted_tasks_to_tune_contexts(
    62	                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    63	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/meta_schedule/relax_integration.py", line 145, in extracted_tasks_to_tune_contexts
    64	    TuneContext(
    65	  File "/home/rbohl/work/ai_tools/tvm/python/tvm/meta_schedule/tune_context.py", line 150, in __init__
    66	    _ffi_api.TuneContextInitialize(self)  # type: ignore # pylint: disable=no-member
    67	    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    68	  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
    69	  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
    70	  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
    71	  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
    72	  File "/home/rbohl/work/ai_tools/tvm/src/meta_schedule/tune_context.cc", line 58, in tvm::meta_schedule::TuneContextNode::Initialize()
    73	    this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
    74	                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    75	  File "/home/rbohl/work/ai_tools/tvm/src/meta_schedule/space_generator/post_order_apply.cc", line 44, in tvm::meta_schedule::PostOrderApplyNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    76	    SpaceGeneratorNode::InitializeWithTuneContext(context);
    77	                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    78	  File "/home/rbohl/work/ai_tools/tvm/src/meta_schedule/space_generator/space_generator.cc", line 144, in tvm::meta_schedule::SpaceGeneratorNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    79	    i->InitializeWithTuneContext(context);
    80	                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    81	  File "/home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/auto_bind.cc", line 36, in tvm::meta_schedule::AutoBindNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    82	    CHECK(max_threads_per_block.defined())
    83	                  ^^^^^^^^^^^^^^^^^^^^^^^^^
    84	ValueError: Traceback (most recent call last):
    85	  3: tvm::meta_schedule::TuneContextNode::Initialize()
    86	        at /home/rbohl/work/ai_tools/tvm/src/meta_schedule/tune_context.cc:58
    87	  2: tvm::meta_schedule::PostOrderApplyNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    88	        at /home/rbohl/work/ai_tools/tvm/src/meta_schedule/space_generator/post_order_apply.cc:44
    89	  1: tvm::meta_schedule::SpaceGeneratorNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    90	        at /home/rbohl/work/ai_tools/tvm/src/meta_schedule/space_generator/space_generator.cc:144
    91	  0: tvm::meta_schedule::AutoBindNode::InitializeWithTuneContext(tvm::meta_schedule::TuneContext const&)
    92	        at /home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/auto_bind.cc:36
    93	  File "/home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/auto_bind.cc", line 36
    94	ValueError: Check failed: (max_threads_per_block.defined()) is false: missing attribute `max_threads_per_block` in the target

Hello,

Any ideas as to what is wrong? The example script should surely work with an RTX 4090 target.

As an alternative, specifying my RTX 4090 target in the pattern of the example’s original RTX 3090 target (using the Target constructor rather than the tvm.target.cuda() function) results in the following error:

RuntimeError: Memory verification failed with the following errors:
 ...
Did you forget to bind?
Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
... several more variables listed for the same problem

at the following statement on line 56 of the script:

target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda")  # Change to your target device

So I’m unable to execute the end-to-end example when specifying an RTX 4090 target. Should I file a bug?
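
For completeness, the one variant I have not tried is the tag-based form with no explicit host argument, exactly mirroring the tutorial's RTX 3090 Ti line (sketch below; I don't know yet whether it avoids the memory-verification failure):

    import tvm

    # Sketch: tag-based target with no explicit host, mirroring the
    # tutorial's original RTX 3090 Ti target line.
    target = tvm.target.Target("nvidia/geforce-rtx-4090")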

I built TVM from the latest source; the build settings reported by tvm.support.libinfo() are listed below. A listing of the script and the corresponding output follow the settings.

USE_NVTX: OFF
USE_GTEST: AUTO
SUMMARIZE: OFF
TVM_DEBUG_WITH_ABI_CHANGE: OFF
USE_IOS_RPC: OFF
USE_MSC: OFF
CUDA_VERSION: 12.8
USE_LIBBACKTRACE: AUTO
DLPACK_PATH: 3rdparty/dlpack/include
USE_TENSORRT_CODEGEN: OFF
USE_OPENCL_EXTN_QCOM: NOT-FOUND
USE_THRUST: OFF
BUILD_DUMMY_LIBTVM: OFF
USE_CUDNN: ON
USE_TENSORRT_RUNTIME: OFF
USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
USE_CCACHE: AUTO
USE_ARM_COMPUTE_LIB: OFF
USE_CPP_RTVM: OFF
USE_OPENCL_GTEST: /path/to/opencl/gtest
TVM_LOG_BEFORE_THROW: OFF
USE_MKL: OFF
MLIR_VERSION: NOT-FOUND
USE_CLML: OFF
USE_STACKVM_RUNTIME: OFF
ROCM_PATH: /opt/rocm
USE_DNNL: OFF
USE_MSCCL: OFF
USE_NNAPI_RUNTIME: OFF
USE_MLIR: OFF
USE_RCCL: OFF
USE_LLVM: llvm-config --ignore-libllvm --link-static
USE_THREADS: ON
USE_MSVC_MT: OFF
BACKTRACE_ON_SEGFAULT: OFF
USE_ROCBLAS: OFF
USE_NCCL: OFF
GIT_COMMIT_HASH: 0d42dc411d8729a97a7d6cc89266da2b50045897
USE_VULKAN: OFF
USE_RUST_EXT: OFF
USE_CUTLASS: OFF
USE_CPP_RPC: OFF
USE_HEXAGON: OFF
USE_CUSTOM_LOGGING: OFF
USE_UMA: OFF
USE_FALLBACK_STL_MAP: OFF
USE_SORT: ON
USE_RTTI: ON
GIT_COMMIT_TIME: 2025-02-18 10:30:10 -0500
USE_HIPBLAS: OFF
USE_HEXAGON_SDK: /path/to/sdk
USE_BLAS: none
USE_LIBTORCH: OFF
USE_RANDOM: ON
USE_CUDA: ON
USE_COREML: OFF
USE_AMX: OFF
BUILD_STATIC_RUNTIME: OFF
USE_KHRONOS_SPIRV: OFF
USE_CLML_GRAPH_EXECUTOR: OFF
USE_TFLITE: OFF
USE_HEXAGON_GTEST: /path/to/hexagon/gtest
PICOJSON_PATH: 3rdparty/picojson
USE_OPENCL_ENABLE_HOST_PTR: OFF
INSTALL_DEV: OFF
USE_NNPACK: OFF
LLVM_VERSION: 19.1.7
USE_MRVL: OFF
USE_OPENCL: OFF
COMPILER_RT_PATH: 3rdparty/compiler-rt
USE_NNAPI_CODEGEN: OFF
RANG_PATH: 3rdparty/rang/include
USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
USE_OPENMP: none
USE_BNNS: OFF
USE_FLASHINFER: OFF
USE_CUBLAS: ON
USE_METAL: OFF
USE_HEXAGON_EXTERNAL_LIBS: OFF
USE_ALTERNATIVE_LINKER: AUTO
USE_BYODT_POSIT: OFF
USE_NVSHMEM: OFF
USE_HEXAGON_RPC: OFF
DMLC_PATH: 3rdparty/dmlc-core/include
INDEX_DEFAULT_I64: ON
USE_RPC: ON
USE_TENSORFLOW_PATH: none
TVM_CLML_VERSION: 
USE_MIOPEN: OFF
USE_ROCM: OFF
USE_PAPI: OFF
USE_CURAND: OFF
TVM_CXX_COMPILER_PATH: /usr/bin/c++
HIDE_PRIVATE_SYMBOLS: ON

Here is the script:

     1  # Licensed to the Apache Software Foundation (ASF) under one
     2  # or more contributor license agreements.  See the NOTICE file
     3  # distributed with this work for additional information
     4  # regarding copyright ownership.  The ASF licenses this file
     5  # to you under the Apache License, Version 2.0 (the
     6  # "License"); you may not use this file except in compliance
     7  # with the License.  You may obtain a copy of the License at
     8  #
     9  #   http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing,
    12  # software distributed under the License is distributed on an
    13  # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    14  # KIND, either express or implied.  See the License for the
    15  # specific language governing permissions and limitations
    16  # under the License.
    17
    18  import os
    19  import numpy as np
    20  import torch
    21  from torch.export import export
    22  from torchvision.models.resnet import ResNet18_Weights, resnet18
    23
    24  import tvm
    25  from tvm import relax
    26  from tvm.relax.frontend.torch import from_exported_program
    27
    28  print("\n##### TVM and CUDA Checks #####")
    29  print(f"TVM File: {tvm.__file__}")
    30  print(f"TVM base lib: {tvm._ffi.base._LIB}")
    31  print(f"Cuda exists?: {tvm.cuda().exist}")
    32  print("################################\n")
    33
    34  torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    35
    36  # Give an example argument to torch.export
    37  example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
    38
    39  # Convert the model to IRModule
    40  with torch.no_grad():
    41      exported_program = export(torch_model, example_args)
    42      mod = from_exported_program(exported_program, keep_params_as_input=True)
    43
    44  mod, params = relax.frontend.detach_params(mod)
    45  # mod.show()
    46
    47  TOTAL_TRIALS = 1  # Change to 20000 for better performance if needed
    48
    49  # target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
    50
    51  # The following target specification generates a RuntimeError
    52  # RuntimeError: Memory verification failed with the following errors:
    53  # ...
    54  # Did you forget to bind?
    55  #    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    56  target = tvm.target.Target("nvidia/geforce-rtx-4090", host="cuda")  # Change to your target device
    57
    58  # The following target specification generates a ValueError
    59  # ValueError: Check failed: (max_threads_per_block.defined()) is false: missing attribute `max_threads_per_block` in the target
    60  # target = tvm.target.cuda(model="4090", arch="sm_89")
    61
    62  print("\n##### Target Info #####")
    63  print(f"Device Type: {target.get_target_device_type()}")
    64  print(f"Model: {target.model}")
    65  print(f"Arch: {target.arch}")
    66  print(f"Max Threads: {target.max_num_threads}")
    67  # The following two lines generate the error shown below
    68  # print(f"Max Block Size x: {target.max_block_size_x}")
    69  # print(f"Max Block Size y: {target.max_block_size_y}")
    70  #   Generates error:
    71  #    return int(self.attrs["max_block_size_x"])
    72  #               ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
    73  #    InternalError: Check failed: (it != n->end()) is false: cannot find the corresponding key in the Map
    74  #
    75  print(f"Thread Warp Size: {target.thread_warp_size}")
    76  print("#########################\n")
    77
    78  work_dir = "tuning_logs"
    79
    80  # Skip running in CI environment
    81  IS_IN_CI = os.getenv("CI", "") == "true"
    82  if not IS_IN_CI:
    83      mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
    84
    85      # Only show the main function
    86      mod["main"].show()
    87
    88  if not IS_IN_CI:
    89      ex = relax.build(mod, target="cuda")
    90      dev = tvm.device("cuda", 0)
    91      vm = relax.VirtualMachine(ex, dev)
    92      # Need to allocate data and params on GPU device
    93      gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    94      gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
    95      gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
    96
    97      print(gpu_out.shape)

Here is the output, excerpted to fit under the maximum posting size:

     1
     2  ##### TVM and CUDA Checks #####
     3  TVM File: /home/rbohl/work/ai_tools/tvm/python/tvm/__init__.py
     4  TVM base lib: <CDLL '/home/rbohl/work/ai_tools/tvm/build/libtvm.so', handle 63be0f4cda10 at 0x74401d268310>
     5  Cuda exists?: True
     6  ################################
     7
     8
     9  ##### Target Info #####
    10  Device Type: 2
    11  Model: unknown
    12  Arch: sm_89
    13  Max Threads: 1024
    14  Thread Warp Size: 32
    15  #########################
    16
    17  2025-02-27 10:05:28 [INFO] Logging directory: tuning_logs/logs
    18  2025-02-27 10:05:35 [INFO] LocalBuilder: max_workers = 8
    19  2025-02-27 10:05:36 [INFO] LocalRunner: max_workers = 1
    20  2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #0: "fused_matmul_add13"
    21  2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #1: "reshape"
    22  2025-02-27 10:05:40 [INFO] [task_scheduler.cc:159] Initializing Task #2: "fused_conv2d_subtract_divide_expand_dims_multiply_expand_dims_add1_relu"
    23  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #3: "adaptive_avg_pool2d"
    24  [10:05:41] /home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/apply_custom_rule.cc:56: Warning: Unknown schedule rule "meta_schedule.adaptive_pool_avg" for target keys "["cuda", "gpu"]". Checked PackedFuncs:
    25    meta_schedule.cuda.meta_schedule.adaptive_pool_avg
    26    meta_schedule.gpu.meta_schedule.adaptive_pool_avg
    27  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #4: "fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4"
    28  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #5: "fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4"
    29  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #6: "max_pool2d"
    30  [10:05:41] /home/rbohl/work/ai_tools/tvm/src/meta_schedule/schedule_rule/apply_custom_rule.cc:56: Warning: Unknown schedule rule "meta_schedule.pool_max" for target keys "["cuda", "gpu"]". Checked PackedFuncs:
    31    meta_schedule.cuda.meta_schedule.pool_max
    32    meta_schedule.gpu.meta_schedule.pool_max
    33  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #7: "fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2"
    34  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #8: "fused_conv2d10_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11"
    35  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #9: "fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3"
    36  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #10: "transpose"
    37  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #11: "fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1"
    38  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #12: "fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1"
    39  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #13: "fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8"
    40  2025-02-27 10:05:41 [INFO] [task_scheduler.cc:159] Initializing Task #14: "fused_conv2d2_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2"
    41  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #15: "fused_conv2d4_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5"
    42  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #16: "fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2"
    43  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #17: "fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4"
    44  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #18: "fused_conv2d5_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3"
    45  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:159] Initializing Task #19: "fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3"
    46  2025-02-27 10:05:42 [INFO] [task_scheduler.cc:320] 

 **... output tables from multiple trial runs ...**

   664  [10:05:48] /home/rbohl/work/ai_tools/tvm/src/relax/transform/meta_schedule.cc:119: Warning: Creating JSONDatabase. Workload at: tuning_logs/database_workload.json, Tuning records at: tuning_logs/database_tuning_record.json
   665  # from tvm.script import relax as R
   666
   667  @R.function
   668  def main(x: R.Tensor((1, 3, 224, 224), dtype="float32"), p_conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), p_bn1_weight: R.Tensor((64,), dtype="float32"), p_bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___0___bn1_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___0___bn2_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___0___bn2_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___1___bn1_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___bn1_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer1___1___bn2_weight: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer1___1___bn2_bias: R.Tensor((64,), dtype="float32"), p_getattr_l__self___layer2___0___conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), p_getattr_l__self___layer2___0___bn1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___bn1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___0___bn2_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___bn2_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), p_getattr_l__self___layer2___0___downsample_1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___0___downsample_1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___1___bn1_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___bn1_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer2___1___bn2_weight: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer2___1___bn2_bias: R.Tensor((128,), dtype="float32"), p_getattr_l__self___layer3___0___conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), p_getattr_l__self___layer3___0___bn1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___bn1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___0___bn2_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___bn2_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), p_getattr_l__self___layer3___0___downsample_1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___0___downsample_1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___1___bn1_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___bn1_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___conv2_weight: 
R.Tensor((256, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer3___1___bn2_weight: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer3___1___bn2_bias: R.Tensor((256,), dtype="float32"), p_getattr_l__self___layer4___0___conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), p_getattr_l__self___layer4___0___bn1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___bn1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___0___bn2_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___bn2_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), p_getattr_l__self___layer4___0___downsample_1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___0___downsample_1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___1___bn1_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___bn1_bias: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_getattr_l__self___layer4___1___bn2_weight: R.Tensor((512,), dtype="float32"), p_getattr_l__self___layer4___1___bn2_bias: R.Tensor((512,), dtype="float32"), p_fc_weight: R.Tensor((1000, 512), dtype="float32"), p_fc_bias: R.Tensor((1000,), dtype="float32")) -> R.Tuple(R.Tensor((1, 1000), dtype="float32")):
   669      R.func_attr({"num_input": 1})
   670      with R.dataflow():
   671          lv = R.call_tir(fused_conv2d_subtract_divide_expand_dims_multiply_expand_dims_add1_relu, (x, p_conv1_weight, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], p_bn1_weight, p_bn1_bias), out_sinfo=R.Tensor((1, 64, 112, 112), dtype="float32"))
   672          lv4 = R.call_tir(max_pool2d, (lv,), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
   673          lv1 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1, (lv4, p_getattr_l__self___layer1___0___conv1_weight, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], p_getattr_l__self___layer1___0___bn1_weight, p_getattr_l__self___layer1___0___bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
   674          lv2 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1, (lv1, p_getattr_l__self___layer1___0___conv2_weight, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], p_getattr_l__self___layer1___0___bn2_weight, p_getattr_l__self___layer1___0___bn2_bias, lv4), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
   675          lv3 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_relu1, (lv2, p_getattr_l__self___layer1___1___conv1_weight, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], p_getattr_l__self___layer1___1___bn1_weight, p_getattr_l__self___layer1___1___bn1_bias), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
   676          lv4_1 = R.call_tir(fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1, (lv3, p_getattr_l__self___layer1___1___conv2_weight, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], p_getattr_l__self___layer1___1___bn2_weight, p_getattr_l__self___layer1___1___bn2_bias, lv2), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
   677          lv5 = R.call_tir(fused_conv2d2_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2, (lv4_1, p_getattr_l__self___layer2___0___conv1_weight, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], p_getattr_l__self___layer2___0___bn1_weight, p_getattr_l__self___layer2___0___bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
   678          lv6 = R.call_tir(fused_conv2d4_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5, (lv4_1, p_getattr_l__self___layer2___0___downsample_0_weight, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], p_getattr_l__self___layer2___0___downsample_1_weight, p_getattr_l__self___layer2___0___downsample_1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
   679          lv7 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2, (lv5, p_getattr_l__self___layer2___0___conv2_weight, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], p_getattr_l__self___layer2___0___bn2_weight, p_getattr_l__self___layer2___0___bn2_bias, lv6), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
   680          lv8 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_relu2, (lv7, p_getattr_l__self___layer2___1___conv1_weight, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], p_getattr_l__self___layer2___1___bn1_weight, p_getattr_l__self___layer2___1___bn1_bias), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
   681          lv9 = R.call_tir(fused_conv2d3_subtract2_divide2_expand_dims1_multiply2_expand_dims1_add5_add6_relu2, (lv8, p_getattr_l__self___layer2___1___conv2_weight, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], p_getattr_l__self___layer2___1___bn2_weight, p_getattr_l__self___layer2___1___bn2_bias, lv7), out_sinfo=R.Tensor((1, 128, 28, 28), dtype="float32"))
   682          lv10 = R.call_tir(fused_conv2d5_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3, (lv9, p_getattr_l__self___layer3___0___conv1_weight, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], p_getattr_l__self___layer3___0___bn1_weight, p_getattr_l__self___layer3___0___bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
   683          lv11 = R.call_tir(fused_conv2d7_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8, (lv9, p_getattr_l__self___layer3___0___downsample_0_weight, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], p_getattr_l__self___layer3___0___downsample_1_weight, p_getattr_l__self___layer3___0___downsample_1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
   684          lv12 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3, (lv10, p_getattr_l__self___layer3___0___conv2_weight, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], p_getattr_l__self___layer3___0___bn2_weight, p_getattr_l__self___layer3___0___bn2_bias, lv11), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
   685          lv13 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_relu3, (lv12, p_getattr_l__self___layer3___1___conv1_weight, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], p_getattr_l__self___layer3___1___bn1_weight, p_getattr_l__self___layer3___1___bn1_bias), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
   686          lv14 = R.call_tir(fused_conv2d6_subtract3_divide3_expand_dims2_multiply3_expand_dims2_add8_add9_relu3, (lv13, p_getattr_l__self___layer3___1___conv2_weight, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], p_getattr_l__self___layer3___1___bn2_weight, p_getattr_l__self___layer3___1___bn2_bias, lv12), out_sinfo=R.Tensor((1, 256, 14, 14), dtype="float32"))
   687          lv15 = R.call_tir(fused_conv2d8_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4, (lv14, p_getattr_l__self___layer4___0___conv1_weight, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], p_getattr_l__self___layer4___0___bn1_weight, p_getattr_l__self___layer4___0___bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
   688          lv16 = R.call_tir(fused_conv2d10_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11, (lv14, p_getattr_l__self___layer4___0___downsample_0_weight, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], p_getattr_l__self___layer4___0___downsample_1_weight, p_getattr_l__self___layer4___0___downsample_1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
   689          lv17 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4, (lv15, p_getattr_l__self___layer4___0___conv2_weight, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], p_getattr_l__self___layer4___0___bn2_weight, p_getattr_l__self___layer4___0___bn2_bias, lv16), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
   690          lv18 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_relu4, (lv17, p_getattr_l__self___layer4___1___conv1_weight, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], p_getattr_l__self___layer4___1___bn1_weight, p_getattr_l__self___layer4___1___bn1_bias), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
   691          lv19 = R.call_tir(fused_conv2d9_subtract4_divide4_expand_dims3_multiply4_expand_dims3_add11_add12_relu4, (lv18, p_getattr_l__self___layer4___1___conv2_weight, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], p_getattr_l__self___layer4___1___bn2_weight, p_getattr_l__self___layer4___1___bn2_bias, lv17), out_sinfo=R.Tensor((1, 512, 7, 7), dtype="float32"))
   692          lv86 = R.call_tir(adaptive_avg_pool2d, (lv19,), out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"))
   693          lv87 = R.call_tir(reshape, (lv86,), out_sinfo=R.Tensor((1, 512), dtype="float32"))
   694          lv88 = R.call_tir(transpose, (p_fc_weight,), out_sinfo=R.Tensor((512, 1000), dtype="float32"))
   695          lv20 = R.call_tir(fused_matmul_add13, (lv87, lv88, p_fc_bias), out_sinfo=R.Tensor((1, 1000), dtype="float32"))
   696          gv: R.Tuple(R.Tensor((1, 1000), dtype="float32")) = (lv20,)
   697          R.output(gv)
   698      return gv
   699
   700  # Metadata omitted. Use show_meta=True in script() method to show it.
   701
   702  Traceback (most recent call last):
   703    File "/home/rbohl/work/ai_tools/tvm/e2e_debug.py", line 85, in <module>
   704      ex = relax.build(mod, target="cuda")
   705           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   706    File "/home/rbohl/work/ai_tools/tvm/python/tvm/relax/vm_build.py", line 352, in build
   707      return _vmlink(
   708             ^^^^^^^^
   709    File "/home/rbohl/work/ai_tools/tvm/python/tvm/relax/vm_build.py", line 252, in _vmlink
   710      lib = tvm.build(tir_mod, target=target)
   711            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   712    File "/home/rbohl/work/ai_tools/tvm/python/tvm/driver/build_module.py", line 145, in build
   713      rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
   714                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   715    File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
   716    File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
   717    File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
   718    File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
   719    File "/home/rbohl/work/ai_tools/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
   720      raise py_err
   721    File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 445, in operator()
   722      return TIRToRuntime(inputs_arg, host_target);
   723                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   724    File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 406, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
   725      auto pair = SplitMixedModule(ir_module, target, target_host);
   726                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   727    File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 333, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
   728      mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
   729                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   730    File "/home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc", line 287, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
   731      mod = seq(std::move(mod));
   732                      ^^^^^^^^^^^
   733    File "/home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc", line 203, in operator()
   734      LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
   735              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   736  tvm._ffi.base.TVMError: Traceback (most recent call last):
   737    4: operator()
   738          at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:445
   739    3: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
   740          at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:406
   741    2: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
   742          at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:333
   743    1: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
   744          at /home/rbohl/work/ai_tools/tvm/src/driver/driver_api.cc:287
   745    0: operator()
   746          at /home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc:203
   747    Did you forget to bind?
   748      Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   749      Variable `lv4` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   750      Variable `p_getattr_l__self___layer1___0___bn2_bias` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   751      Variable `p_getattr_l__self___layer1___0___bn2_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   752      Variable `param_1` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   753      Variable `param_0` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   754      Variable `p_getattr_l__self___layer1___0___conv2_weight` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   755      Variable `lv8` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
   756    File "/home/rbohl/work/ai_tools/tvm/src/tir/analysis/verify_memory.cc", line 203
   757  RuntimeError: Memory verification failed with the following errors:
   758  # from tvm.script import tir as T
   759
   760  @T.prim_func
   761  def fused_conv2d1_subtract1_divide1_expand_dims_multiply1_expand_dims_add2_add3_relu1(lv8: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), p_getattr_l__self___layer1___0___conv2_weight: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), param_0: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"), param_1: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"), p_getattr_l__self___layer1___0___bn2_weight: T.Buffer((T.int64(64),), "float32"), p_getattr_l__self___layer1___0___bn2_bias: T.Buffer((T.int64(64),), "float32"), lv4: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")):
   762      T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-conda-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
   763      pad_temp = T.allocate([215296], "float32", "global")
   764      conv2d_nchw = T.allocate([200704], "float32", "global")
   765      expand_dims = T.allocate([64], "float32", "global")
   766      pad_temp_1 = T.Buffer((T.int64(215296),), data=pad_temp)
   767      for i1, i2, i3 in T.grid(64, 58, 58):
   768          lv8_1 = T.Buffer((T.int64(200704),), data=lv8.data)
   769          pad_temp_1[i1 * 3364 + i2 * 58 + i3] = T.if_then_else(1 <= i2 and i2 < 57 and 1 <= i3 and i3 < 57, lv8_1[i1 * 3136 + i2 * 56 + i3 - 57], T.float32(0.0))
   770      conv2d_nchw_1 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   771      for ff, yy, xx, rc, ry, rx in T.grid(64, 56, 56, 64, 3, 3):
   772          cse_var_1: T.int32 = ff * 3136 + yy * 56 + xx
   773          if rc == 0 and ry == 0 and rx == 0:
   774              conv2d_nchw_1[cse_var_1] = T.float32(0.0)
   775          p_getattr_l__self___layer1___0___conv2_weight_1 = T.Buffer((T.int64(36864),), data=p_getattr_l__self___layer1___0___conv2_weight.data)
   776          conv2d_nchw_1[cse_var_1] = conv2d_nchw_1[cse_var_1] + pad_temp_1[rc * 3364 + yy * 58 + ry * 58 + xx + rx] * p_getattr_l__self___layer1___0___conv2_weight_1[ff * 576 + rc * 9 + ry * 3 + rx]
   777      conv2d_nchw_2 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   778      for ax1, ax2, ax3 in T.grid(64, 56, 56):
   779          cse_var_2: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
   780          param_0_1 = T.Buffer((T.int64(64),), data=param_0.data)
   781          conv2d_nchw_2[cse_var_2] = conv2d_nchw_1[cse_var_2] - param_0_1[ax1]
   782      conv2d_nchw_3 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   783      for ax1, ax2, ax3 in T.grid(64, 56, 56):
   784          cse_var_3: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
   785          param_1_1 = T.Buffer((T.int64(64),), data=param_1.data)
   786          conv2d_nchw_3[cse_var_3] = conv2d_nchw_2[cse_var_3] / param_1_1[ax1]
   787      expand_dims_1 = T.Buffer((T.int64(64),), data=expand_dims)
   788      for i1 in range(64):
   789          p_getattr_l__self___layer1___0___bn2_weight_1 = T.Buffer((T.int64(64),), data=p_getattr_l__self___layer1___0___bn2_weight.data)
   790          expand_dims_1[i1] = p_getattr_l__self___layer1___0___bn2_weight_1[i1]
   791      conv2d_nchw_4 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   792      for ax1, ax2, ax3 in T.grid(64, 56, 56):
   793          cse_var_4: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
   794          conv2d_nchw_4[cse_var_4] = conv2d_nchw_3[cse_var_4] * expand_dims_1[ax1]
   795      expand_dims_2 = T.Buffer((T.int64(64),), data=expand_dims)
   796      for i1 in range(64):
   797          p_getattr_l__self___layer1___0___bn2_bias_1 = T.Buffer((T.int64(64),), data=p_getattr_l__self___layer1___0___bn2_bias.data)
   798          expand_dims_2[i1] = p_getattr_l__self___layer1___0___bn2_bias_1[i1]
   799      conv2d_nchw_5 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   800      for ax1, ax2, ax3 in T.grid(64, 56, 56):
   801          cse_var_5: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
   802          conv2d_nchw_5[cse_var_5] = conv2d_nchw_4[cse_var_5] + expand_dims_2[ax1]
   803      conv2d_nchw_6 = T.Buffer((T.int64(200704),), data=conv2d_nchw)
   804      for ax1, ax2, ax3 in T.grid(64, 56, 56):
   805          cse_var_6: T.int32 = ax1 * 3136 + ax2 * 56 + ax3
   806          lv4_1 = T.Buffer((T.int64(200704),), data=lv4.data)
   807          conv2d_nchw_6[cse_var_6] = conv2d_nchw_5[cse_var_6] + lv4_1[cse_var_6]
   808      for i1, i2, i3 in T.grid(64, 56, 56):
   809          cse_var_7: T.int32 = i1 * 3136 + i2 * 56 + i3
   810          compute_intermediate_1 = T.Buffer((T.int64(200704),), data=compute_intermediate.data)
   811          compute_intermediate_1[cse_var_7] = T.max(conv2d_nchw_6[cse_var_7], T.float32(0.0))