The auto-scheduler fails to measure candidate programs with the following error:
InternalError: Check failed: (bf != nullptr) is false: target.build.rocm is not enabled
I compiled TVM v0.19 from source, using the tag v0.19.0 from the cloned GitHub repository, with amdclang and amdclang++ 17.0.0 from ROCm 6.1.0 as the C and C++ compilers. I configured the build explicitly with the following flags (USE_CUDA is OFF by default; I checked it twice):
-DUSE_OPENMP=ON
-DUSE_ROCBLAS=ON
-DUSE_ROCM=ON
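To confirm that these flags actually ended up in the libtvm that the Python package loads, the compile-time options can be inspected. A quick sketch (my understanding is that tvm.support.libinfo() exposes them):

import tvm
from tvm import support

# Compile-time options recorded in the loaded libtvm
info = support.libinfo()
for key in ("USE_ROCM", "USE_ROCBLAS", "USE_OPENMP", "USE_CUDA"):
    print(key, "=", info.get(key))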
TVM can detect the ROCm GPU and copy data to it, but the auto-scheduler fails. Could this be an issue with the auto-scheduler, or should I compile with different flags? What could the issue be? (ROCm and HIP are functional on my system; I have been able to use ROCm/rocBLAS and to compile and run programs with HIP.)
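As far as I can tell, the failing check is a lookup of a packed function named target.build.rocm, which can be queried directly from Python (a sketch; allow_missing=True makes get_global_func return None instead of raising):

import tvm

# None here means the ROCm codegen was not compiled into libtvm
bf = tvm.get_global_func("target.build.rocm", allow_missing=True)
print("target.build.rocm registered:", bf is not None)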
Here is a reproducer for the error:
import numpy as np

import tvm
from tvm import te, auto_scheduler
from tvm.contrib import rocm  # side effect: registers the tvm_callback_rocm_* link callbacks
def test_rocm_gpu_detection():
    dev = tvm.device("rocm", 0)
    if dev.exist:
        print("ROCm GPU detected!")
    else:
        print("No ROCm GPU detected.")

    def to_tvm_tensor(x):
        if isinstance(x, np.ndarray):
            # Copy the NumPy array into a TVM NDArray on the ROCm device
            return tvm.nd.array(x, tvm.rocm(0))
        raise TypeError("Input should be a NumPy array.")

    x = np.random.rand(10, 10).astype(np.float64)
    y = np.random.rand(10, 10).astype(np.float64)
    tx = to_tvm_tensor(x)
    ty = to_tvm_tensor(y)
    print(tx)
    return tx, ty

# First half of the Jacobi ping-pong update: reads A, writes B_comp.
@auto_scheduler.register_workload("jacobi_2d_1_gpu")
def jacobi_2d_1_gpu(N, dtype):
    A = te.placeholder((N, N), name="A", dtype=dtype)

    def compute_step(A):
        return te.compute(
            (N, N),
            lambda i, j: te.if_then_else(
                te.all(i >= 1, i < N - 1, j >= 1, j < N - 1),
                0.2 * (
                    A[i, j]        # center
                    + A[i, j - 1]  # left
                    + A[i, j + 1]  # right
                    + A[i - 1, j]  # top
                    + A[i + 1, j]  # bottom
                ),
                A[i, j],
            ),
            name="B_comp",
        )

    B_comp = compute_step(A)
    return [A, B_comp]

# Second half of the ping-pong update: reads B, writes A_comp.
@auto_scheduler.register_workload("jacobi_2d_2_gpu")
def jacobi_2d_2_gpu(N, dtype):
    B = te.placeholder((N, N), name="B", dtype=dtype)

    def compute_step(B):
        return te.compute(
            (N, N),
            lambda i, j: te.if_then_else(
                te.all(i >= 1, i < N - 1, j >= 1, j < N - 1),
                0.2 * (
                    B[i, j]        # center
                    + B[i, j - 1]  # left
                    + B[i, j + 1]  # right
                    + B[i - 1, j]  # top
                    + B[i + 1, j]  # bottom
                ),
                B[i, j],
            ),
            name="A_comp",
        )

    A_comp = compute_step(B)
    return [A_comp, B]

_kernel1 = None


def autotune(func, name, args, target):
    task = auto_scheduler.SearchTask(func=func, args=args, target=target)
    # I have also tried passing an explicit host target:
    # target_host = tvm.target.Target("llvm")
    # task = auto_scheduler.SearchTask(func=func, args=args, target=target, target_host=target_host)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=2000,
        measure_callbacks=[auto_scheduler.RecordToFile(f"{name}.json")],
        verbose=2,
    )
    # Run the search; this is where the measurement failures occur
    task.tune(tuning_options=tune_option, search_policy=auto_scheduler.SketchPolicy(task))
    sch, args = task.apply_best(f"{name}.json")
    with tvm.target.Target(target):
        _kernel = tvm.build(sch, args, target=target)
    # Smoke-test the built kernel on freshly allocated device buffers
    dev = tvm.rocm(0)
    bufs = [
        tvm.nd.array(np.random.rand(*[int(s) for s in t.shape]).astype(t.dtype), dev)
        for t in args
    ]
    _kernel(*bufs)
    return _kernel

def autotuner(TSTEPS, A, B):
    global _kernel1
    if _kernel1 is not None:
        return
    dtype = A.dtype
    M = int(A.shape[0])
    N = int(A.shape[1])
    assert M == N
    _kernel1 = autotune(
        func=jacobi_2d_1_gpu,
        name="jacobi_2d_1_gpu",
        args=(N, dtype),
        target=tvm.target.rocm(),
    )


def kernel(TSTEPS, A, B):
    global _kernel1
    for _ in range(1, TSTEPS):
        _kernel1(A, B)
    return A

if __name__ == "__main__":
    tx, ty = test_rocm_gpu_detection()
    autotuner(20, tx, ty)
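For what it's worth, the same check can apparently be hit without the auto-scheduler at all: directly building any TE schedule for the rocm target goes through the same codegen lookup. A minimal sketch, assuming the legacy te.create_schedule API is available in this build (the reproducer above relies on the same legacy flow):

import tvm
from tvm import te

# Trivial elementwise kernel: building it for "rocm" exercises the same
# target.build.rocm entry that the auto-scheduler's measurer needs.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
# GPU targets require binding the loop to block/thread indices before building
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B], target="rocm")
print("built:", mod)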