Error when compiling a model with ROCm

Hi, I’m trying to compile the sample model with ROCm. I’m using the model from the document “Compiling and Optimizing a Model with TVMC” and simply changed the target option from “llvm” to “rocm”, but I get an LLVM error:

TVMError: LLVM module verification failed with the following errors:
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %83 = icmp slt i32 %29, 2
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %97 = mul nsw i32 %29, 4
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %110 = mul nsw i32 %29, 4
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.1
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.2
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.3
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.4
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.5
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.6
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.7
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.8
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.9
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.10
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.11
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.12
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.13
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.14
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.15
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.16
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.17
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.18
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.19
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.20
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.21
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.22
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.23
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.24
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.25
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.26
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.27
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.28
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.29
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.30
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.31
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.32
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.33
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.34
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.35
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.36
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.37
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.38
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.39
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.40
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.41
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.42
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.43
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.44
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.45
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.46
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.47
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.48
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.49
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.50
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.51
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.52
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.53
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.54
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.55
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.56
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.57
Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.58

Does anyone know how to solve this issue?
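The verifier output above mixes two kinds of complaints: a few dominance violations that all involve %29 (the result of @llvm.amdgcn.workitem.id.x()) and a long run of shared-memory globals (@shmem through @shmem.58) flagged for their linkage. When comparing this message across builds or commits, tallying it can be easier than eyeballing it. Below is a minimal Python sketch; the helper names are hypothetical (none of this is a TVM API), and the regular expressions assume the exact message wording shown above.

```python
import re
from collections import Counter

# Verifier messages exactly as they appear in the TVMError text above.
DOMINANCE = "Instruction does not dominate all uses!"
LINKAGE = "Global is external, but doesn't have external or weak linkage!"


def summarize_verifier_errors(message: str) -> Counter:
    """Count how many times each known verifier complaint occurs in `message`."""
    return Counter({err: message.count(err) for err in (DOMINANCE, LINKAGE) if err in message})


def flagged_shared_globals(message: str) -> list:
    """Names of addrspace(3) (shared-memory) globals reported with bad linkage."""
    return re.findall(r"ptr addrspace\(3\) (@shmem(?:\.\d+)?)", message)


def dominance_violations(message: str) -> list:
    """(definition, use) SSA value pairs taken from each dominance complaint.

    Assumes the defining instruction's text contains no '%' of its own, which holds
    for the operand-free intrinsic call `%29 = call i32 @llvm.amdgcn.workitem.id.x()`.
    """
    return re.findall(r"does not dominate all uses! (%\d+) = [^%]*(%\d+) =", message)


if __name__ == "__main__":
    sample = (
        "Instruction does not dominate all uses! %29 = call i32 @llvm.amdgcn.workitem.id.x() "
        "%83 = icmp slt i32 %29, 2 "
        "Instruction does not dominate all uses! %29 = call i32 @llvm.amdgcn.workitem.id.x() "
        "%97 = mul nsw i32 %29, 4 "
        "Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem "
        "Global is external, but doesn't have external or weak linkage! ptr addrspace(3) @shmem.1"
    )
    print(summarize_verifier_errors(sample))
    print(flagged_shared_globals(sample))  # ['@shmem', '@shmem.1']
    print(dominance_violations(sample))    # [('%29', '%83'), ('%29', '%97')]
```

Pasting the full TVMError text in as `sample` would show at a glance whether a new commit changes the number of dominance errors or only the shared-memory linkage errors.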

This is a known issue; some related discussion can be found in #14901 (comment). The community still seems to be trying to figure it out. @Lunderberg

As a quick workaround, you can try rolling back to an older release of TVM, or take a look at this unofficial HIP source codegen backend: hip mfma


Are there any updates on this issue? I am facing the same problem when trying to build ResNet-18 with TVM using ROCm as the target. Any help is appreciated, thanks.

For what it’s worth, I also see this issue when building at commit #c6f281. I’ve done some investigation and believe:

  1. This general class of problem first appeared after PR #14564 (@Lunderberg), but that PR moved the calls to the validation function earlier, so it most likely exposed issues that were already latent rather than introducing them.
  2. As @LeiWang1999 notes, there was discussion around PR #14901. I believe that PR introduced a new validation issue that was later fixed (returning void); while @LeiWang1999 was then able to work with her backend, the underlying issue with resnet-18 is likely still there.
  3. There have been subsequent fixes in PR #15777 (@junrushao) and changes in PR #15464 (@spectrometerHBH), but I still see the problem with resnet18 even after those fixes.
  4. The easiest way I have found to reproduce the issue is to run `cd /source/tvm/apps/benchmark` followed by `python3 gpu_imagenet_bench.py --target=rocm --network=resnet-18`
  5. I build ROCm TVM using the following dockerfile: https://github.com/mvermeulen/rocm-tvm/blob/master/dockerfile/rocm-tvm%3A5.7

I am still trying to characterize the problem and understand what it is telling me, but if any of the tagged folks need additional information, let me know (also tagging @masahi for awareness).

Yeah, my fix was for a case we encountered in Llama, but I’m not sure about ResNet right now. Could you provide a minimal TIR example so that @Lunderberg or I can take a look?

I am still learning how to create and narrow down TIR test cases. In the meantime, I posted the tail end of a debug log here: https://github.com/mvermeulen/rocm-tvm/blob/master/tutorial/quickstart_rocm_log.txt

To produce this log, I:

  • Built TVM with USE_RELAY_DEBUG ON
  • Modified the src/tvm/gallery/tutorial/relay_quick_start.py example by replacing cuda with rocm
  • Ran the quick start example with TVM_LOG_DEBUG=DEFAULT=2
  • Included only the last @T.prim_func and all lines from there until the crash, including the traceback

Does this log help show what TVM was processing when it crashed? Are there different debug flags I should use for instrumentation? Are there other pointers on how to narrow down this example without too much work (via logging, pruning, or something else)?

Unfortunately, the generated TIR you shared is the CPU-side kernel launch logic rather than the ROCm kernel itself. How about this:

Add the following line immediately before https://github.com/apache/tvm/blob/853732e5efea55da584d7660a9123eee09819d42/src/target/llvm/codegen_llvm.cc#L233:

`LOG(INFO) << "Adding PrimFunc:\n" << f;`

This will dump every TIR PrimFunc that code is generated for.

BTW, don’t use USE_RELAY_DEBUG or TVM_LOG_DEBUG as they usually just dump too much irrelevant information.
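With that LOG(INFO) line in place the log can get long, and the interesting part is usually the last PrimFunc dumped before verification fails. A small script can pull it out. This is a sketch with a hypothetical helper name (not a TVM utility); it assumes each dump starts on a line containing "Adding PrimFunc" (the text of the LOG(INFO) message) and that the kernel name can be read from the first `def <name>(` line that follows.

```python
import re
import sys

MARKER = "Adding PrimFunc"  # text of the LOG(INFO) message added to the codegen


def last_primfunc_dump(log_text: str, marker: str = MARKER):
    """Return (kernel_name, text) for the last PrimFunc dumped in a build log.

    Everything after the last marker line is returned, so if the build crashed the
    text also contains the traceback. Returns None if no dump is found.
    """
    lines = log_text.splitlines()
    starts = [i for i, line in enumerate(lines) if marker in line]
    if not starts:
        return None
    text = "\n".join(lines[starts[-1] + 1:])
    match = re.search(r"def\s+(\w+)\s*\(", text)
    return (match.group(1) if match else None), text


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python3 last_primfunc.py build.log
    with open(sys.argv[1]) as f:
        result = last_primfunc_dump(f.read())
    if result is None:
        print("no PrimFunc dumps found")
    else:
        name, text = result
        print(f"last kernel compiled: {name}")
        print(text)
```

Redirecting the build output to a file (for example `python3 relay_quick_start.py > build.log 2>&1`) and passing that file to the script prints just the failing kernel and the traceback.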

I added logging where CodeGenLLVM::AddFunction and CodeGenAMDGPU::AddFunction are called and have included the last @T.prim_func below. In case the formatting below gets messed up, I also put a copy of the full output, resnet18.txt, in this directory: https://github.com/mvermeulen/rocm-tvm/tree/master/apps

Out of curiosity, I also tried the other models in gpu_imagenet_bench.py. They behave as follows:

  • resnet18, mobilenet, and vgg-16 all fail while compiling what looks like a fused softmax kernel.
  • densenet121 fails on a fused maxpool kernel.
  • squeezenet 1.1 has no failures and reports a benchmark result.

```
[22:57:22] /src/tvm/src/target/llvm/codegen_amdgpu.cc:92: Adding PrimFunc GPU:

# from tvm.script import tir as T

@T.prim_func
def tvmgen_default_fused_nn_softmax_kernel(T_softmax_norm: T.handle("float32"), p0: T.handle("float32")):
    T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["rocm", "gpu"], "kind": "rocm", "max_num_threads": 256, "max_shared_memory_per_block": 65536, "max_threads_per_block": 256, "mcpu": "gfx906", "model": "unknown", "mtriple": "amdgcn-amd-amdhsa-hcc", "tag": "", "thread_warp_size": 64}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)})
    T_softmax_norm_1 = T.decl_buffer((1000,), data=T_softmax_norm)
    red_buf0 = T.handle("float32", "shared")
    red_buf0_1 = T.decl_buffer((64,), data=red_buf0, scope="shared")
    normal_reduce_temp0 = T.handle("float32", "local")
    normal_reduce_temp0_1 = T.decl_buffer((1,), data=normal_reduce_temp0, scope="local")
    T_softmax_exp = T.handle("float32", "local")
    T_softmax_exp_1 = T.decl_buffer((1000,), data=T_softmax_exp, scope="local")
    red_buf0_2 = T.handle("float32", "shared")
    red_buf0_3 = T.decl_buffer((64,), data=red_buf0_2, scope="shared")
    p0_1 = T.decl_buffer((1000,), data=p0)
    normal_reduce_temp0_2 = T.handle("float32", "local")
    normal_reduce_temp0_3 = T.decl_buffer((1,), data=normal_reduce_temp0_2, scope="local")
    blockIdx_x = T.launch_thread("blockIdx.x", 1)
    normal_reduce_temp0_2 = T.allocate([1], "float32", "local")
    red_buf0_2 = T.allocate([64], "float32", "shared")
    T.attr(red_buf0_2, "volatile_scope", 1)
    T_softmax_exp = T.allocate([16], "float32", "local")
    normal_reduce_temp0 = T.allocate([1], "float32", "local")
    red_buf0 = T.allocate([64], "float32", "shared")
    T.attr(red_buf0, "volatile_scope", 1)
    threadIdx_x = T.env_thread("threadIdx.x")
    with T.launch_thread(threadIdx_x, 64):
        normal_reduce_temp0_3[0] = T.float32(-3.4028234663852886e+38)
        for k_inner in range(16):
            if threadIdx_x * 2 + T.shift_right(k_inner, 3) < 125:
                normal_reduce_temp0_3[0] = T.max(normal_reduce_temp0_3[0], p0_1[threadIdx_x * 16 + k_inner])
        with T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e+38)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
            T.tvm_storage_sync("shared")
            red_buf0_3[threadIdx_x] = normal_reduce_temp0_3[0]
            T.tvm_storage_sync("shared")
            if threadIdx_x < 32:
                with T.LetStmt(T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 32])) as w_32_0:
                    T.tvm_storage_sync("warp")
                    red_buf0_3[threadIdx_x] = w_32_0
                    T.tvm_storage_sync("warp")
                with T.LetStmt(T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 16])) as w_16_0:
                    T.tvm_storage_sync("warp")
                    red_buf0_3[threadIdx_x] = w_16_0
                    T.tvm_storage_sync("warp")
                with T.LetStmt(T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 8])) as w_8_0:
                    T.tvm_storage_sync("warp")
                    red_buf0_3[threadIdx_x] = w_8_0
                    T.tvm_storage_sync("warp")
                with T.LetStmt(T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 4])) as w_4_0:
                    T.tvm_storage_sync("warp")
                    red_buf0_3[threadIdx_x] = w_4_0
                    T.tvm_storage_sync("warp")
                with T.LetStmt(T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 2])) as w_2_0:
                    T.tvm_storage_sync("warp")
                    red_buf0_3[threadIdx_x] = w_2_0
                    T.tvm_storage_sync("warp")
                w_1_0: T.float32 = T.max(red_buf0_3[threadIdx_x], red_buf0_3[threadIdx_x + 1])
                T.tvm_storage_sync("warp")
                red_buf0_3[threadIdx_x] = w_1_0
                T.tvm_storage_sync("warp")
            T.tvm_storage_sync("shared")
        for i1_inner_outer in range(4):
            if threadIdx_x * 2 + T.shift_right(i1_inner_outer, 1) < 125:
                T_softmax_exp_1[i1_inner_outer * 4:i1_inner_outer * 4 + 4] = T.call_pure_extern("float32x4", "__ocml_exp_f32", p0_1[threadIdx_x * 16 + i1_inner_outer * 4:threadIdx_x * 16 + i1_inner_outer * 4 + 4] - T.Broadcast(red_buf0_3[0], 4))
    T.launch_thread(threadIdx_x, 64)
    normal_reduce_temp0_1[0] = T.float32(0)
    T.tvm_storage_sync("warp")
    for k_inner in range(16):
        if threadIdx_x * 2 + T.shift_right(k_inner, 3) < 125:
            normal_reduce_temp0_1[0] = normal_reduce_temp0_1[0] + T_softmax_exp_1[k_inner]
    with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
        T.tvm_storage_sync("shared")
        red_buf0_1[threadIdx_x] = normal_reduce_temp0_1[0]
        T.tvm_storage_sync("shared")
        if threadIdx_x < 32:
            with T.LetStmt(red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 32]) as w_32_0:
                T.tvm_storage_sync("warp")
                red_buf0_1[threadIdx_x] = w_32_0
                T.tvm_storage_sync("warp")
            with T.LetStmt(red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 16]) as w_16_0:
                T.tvm_storage_sync("warp")
                red_buf0_1[threadIdx_x] = w_16_0
                T.tvm_storage_sync("warp")
            with T.LetStmt(red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 8]) as w_8_0:
                T.tvm_storage_sync("warp")
                red_buf0_1[threadIdx_x] = w_8_0
                T.tvm_storage_sync("warp")
            with T.LetStmt(red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 4]) as w_4_0:
                T.tvm_storage_sync("warp")
                red_buf0_1[threadIdx_x] = w_4_0
                T.tvm_storage_sync("warp")
            with T.LetStmt(red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 2]) as w_2_0:
                T.tvm_storage_sync("warp")
                red_buf0_1[threadIdx_x] = w_2_0
                T.tvm_storage_sync("warp")
            w_1_0: T.float32 = red_buf0_1[threadIdx_x] + red_buf0_1[threadIdx_x + 1]
            T.tvm_storage_sync("warp")
            red_buf0_1[threadIdx_x] = w_1_0
            T.tvm_storage_sync("warp")
        T.tvm_storage_sync("shared")
    for i1_inner_outer in range(4):
        if threadIdx_x * 2 + T.shift_right(i1_inner_outer, 1) < 125:
            T_softmax_norm_1[threadIdx_x * 16 + i1_inner_outer * 4:threadIdx_x * 16 + i1_inner_outer * 4 + 4] = T_softmax_exp_1[i1_inner_outer * 4:i1_inner_outer * 4 + 4] / T.Broadcast(red_buf0_1[0], 4)

Traceback (most recent call last):
  File "/src/rocm-tvm/tutorial/relay_quick_start.py", line 102, in <module>
    lib = relay.build(mod, target, params=params)
  File "/usr/local/lib/python3.10/dist-packages/tvm-0.7.0.dev6030+g4de435be7-py3.10-linux-x86_64.egg/tvm/relay/build_module.py", line 364, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/usr/local/lib/python3.10/dist-packages/tvm-0.7.0.dev6030+g4de435be7-py3.10-linux-x86_64.egg/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "/usr/local/lib/python3.10/dist-packages/tvm-0.7.0.dev6030+g4de435be7-py3.10-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/usr/local/lib/python3.10/dist-packages/tvm-0.7.0.dev6030+g4de435be7-py3.10-linux-x86_64.egg/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  6: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  5: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  4: _ZN3tvm7runtime13PackedFuncObj
  3: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::codegen::BuildAMDGPU(tvm::IRModule, tvm::Target)
  1: tvm::codegen::CodeGenLLVM::Finish()
  0: tvm::codegen::CodeGenLLVM::Verify() const
  File "/src/tvm/src/target/llvm/codegen_llvm.cc", line 355
TVMError: LLVM module verification failed with the following errors:
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %83 = icmp slt i32 %29, 2
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %90 = mul nsw i32 %29, 4
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %103 = mul nsw i32 %29, 4
```

Here is a strange observation I made while looking at the TIR dumps:

  1. The last kernel compiled before densenet-121 crashes is tvmgen_default_fused_nn_max_pool2d_add_nn_relu_kernel. A textually identical kernel appears as the second-to-last (N-1) GPU kernel in resnet-18; however, resnet-18 doesn’t hit the assert until the next kernel, tvmgen_default_fused_nn_softmax_kernel.

It may not mean anything. I got there because the softmax kernel was larger, and I was curious whether the densenet kernel might be easier to investigate; then I noticed the same kernel also appears in resnet. I’m pointing it out here in case it gives any clues or suggests places where I can help look.
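To check an observation like this more systematically, a sketch along these lines could compare two such logs (for example, the densenet-121 and resnet-18 dumps) and list the kernels that appear in both with textually identical TIR. The helper names are hypothetical; it assumes the "Adding PrimFunc" marker line from the LOG(INFO) logging suggested earlier in the thread, that each dump contains a `def <name>(` line, and that a crash traceback starts on a line beginning with "Traceback".

```python
import re

MARKER = "Adding PrimFunc"  # text of the LOG(INFO) message added to the codegen


def kernels_by_name(log_text: str, marker: str = MARKER) -> dict:
    """Map kernel name -> dumped TIR text for each PrimFunc logged in a build log.

    A dump runs from the line after a marker line up to the next marker line or
    a line starting with "Traceback", so a crash traceback is not counted as part
    of the last kernel. If the same name is dumped twice, the later dump wins.
    """
    kernels, current = {}, None
    for line in log_text.splitlines() + [marker]:  # sentinel marker flushes the last dump
        if marker in line or line.startswith("Traceback"):
            if current:
                text = "\n".join(current).strip()
                match = re.search(r"def\s+(\w+)\s*\(", text)
                if match:
                    kernels[match.group(1)] = text
            current = [] if marker in line else None  # stop collecting after a traceback
        elif current is not None:
            current.append(line)
    return kernels


def identical_kernels(log_a: str, log_b: str) -> list:
    """Sorted names of kernels dumped in both logs with textually identical TIR."""
    a, b = kernels_by_name(log_a), kernels_by_name(log_b)
    return sorted(name for name in a.keys() & b.keys() if a[name] == b[name])
```

Because the timestamped marker lines are dropped, two dumps of the same kernel compare equal even when they were logged at different times.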

I took a look at the LLVM IR it is complaining about. It looks like a variation of what @junrushao fixed with PR #15777.

  1. The results of @llvm.workgroup.id.* and @llvm.workitem.id.* are returned (here in %29) and cast to int32; PR #15777 added that explicit cast.
  2. Later uses of this result appear in other contexts inside a loop, such as an “slt” comparison or a “mul” instruction; these look like “if” conditions and index computations to me. Both the intrinsic call and the if/index operations are inside a loop. For example, the verifier flags:
%29 = call i32 @llvm.amdgcn.workitem.id.x()
%83 = icmp slt i32 %29, 2

with the message “Instruction does not dominate all uses!”

I don’t immediately see the problem here, unless there is some kind of conflict in the upper bits of a data type (e.g. if that constant “2” were 64 bits) or something in the control flow. However, the same control flow should apply to other targets such as nvptx, so perhaps data types are the more likely cause?
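For context on the message itself: in SSA form, LLVM requires the instruction that defines a value to dominate every use of it, meaning every control-flow path from the function entry to the use must pass through the definition (PHI nodes and ordering within a single block aside). The Python sketch below is a textbook iterative dominator computation on a made-up four-block loop CFG, not TVM's actual generated IR; it only illustrates one way the message can arise (a value defined inside a loop body and used after the loop) and does not diagnose this particular kernel.

```python
def dominators(cfg: dict, entry: str) -> dict:
    """Dominator sets via the classic iterative data-flow formulation:

        dom(entry) = {entry}
        dom(n)     = {n} | intersection of dom(p) over all predecessors p of n

    `cfg` maps every basic block to a list of its successors (use [] for exits).
    """
    nodes = set(cfg)
    preds = {n: set() for n in nodes}
    for src, succs in cfg.items():
        for dst in succs:
            preds[dst].add(src)
    dom = {n: set(nodes) for n in nodes}  # start from "every block dominates n"
    dom[entry] = {entry}
    changed = True
    while changed:  # iterate until a fixed point is reached
        changed = False
        for n in nodes - {entry}:
            meet = set.intersection(*(dom[p] for p in preds[n])) if preds[n] else set(nodes)
            new = {n} | meet
            if new != dom[n]:
                dom[n], changed = new, True
    return dom


def def_dominates_use(cfg: dict, entry: str, def_block: str, use_block: str) -> bool:
    """True if a value defined in def_block may legally be used in use_block."""
    return def_block in dominators(cfg, entry)[use_block]


# Hypothetical loop-shaped CFG: entry -> header; header -> body | exit; body -> header.
LOOP_CFG = {
    "entry": ["header"],
    "header": ["body", "exit"],
    "body": ["header"],
    "exit": [],
}

if __name__ == "__main__":
    # Defined in the loop body, used after the loop: the path entry -> header -> exit
    # skips the body, which is the situation the verifier reports.
    print(def_dominates_use(LOOP_CFG, "entry", "body", "exit"))   # False
    # Defined in the entry block: every path to "exit" passes through it.
    print(def_dominates_use(LOOP_CFG, "entry", "entry", "exit"))  # True
```

The same check explains why a value computed once in the entry block can be reused anywhere, while one computed inside a scope that control flow can bypass cannot.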

I have suppressed the symptom in my local build by removing the Verify() call that precedes the Optimize() call in CodeGenLLVM::Finish(). That is not a general solution, since it removes a check that @Lunderberg added in PR #14564, but it lets me proceed without crashing here. I don’t know enough to fix this more generally and would appreciate any help.