OpenCL ML runtime not available even after building with support

@srkreddy1238

Hi, I built TVM with Adreno support, but the check for whether the CLML runtime is enabled returns False.

The snippet below returns False when I run it on the host. Obviously, the host that runs the Python interface does not support Adreno (I am using an M1 Mac), but I assumed that the is_clml_runtime_enabled check would verify the CLML support status of the runtime .so:

from tvm.relay.op.contrib import clml
clml.is_clml_runtime_enabled()

This is my build output after following these instructions of yours:

python3 -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))"
USE_NVTX: OFF
USE_GTEST: AUTO
SUMMARIZE: OFF
TVM_DEBUG_WITH_ABI_CHANGE: OFF
USE_IOS_RPC: OFF
USE_MSC: OFF
USE_ETHOSU: OFF
CUDA_VERSION: NOT-FOUND
USE_LIBBACKTRACE: AUTO
DLPACK_PATH: 3rdparty/dlpack/include
USE_TENSORRT_CODEGEN: OFF
USE_THRUST: OFF
USE_TARGET_ONNX: OFF
USE_AOT_EXECUTOR: ON
BUILD_DUMMY_LIBTVM: OFF
USE_CUDNN: OFF
USE_TENSORRT_RUNTIME: OFF
USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
USE_CCACHE: AUTO
USE_ARM_COMPUTE_LIB: OFF
USE_CPP_RTVM: OFF
USE_OPENCL_GTEST: /path/to/opencl/gtest
TVM_LOG_BEFORE_THROW: OFF
USE_MKL: OFF
USE_PT_TVMDSOOP: OFF
MLIR_VERSION: NOT-FOUND
USE_CLML: /Users/varunnaw/FastConvolution/src/FastConvolution/adreno_opencl_ml_sdk_v3.0
USE_STACKVM_RUNTIME: OFF
USE_GRAPH_EXECUTOR_CUDA_GRAPH: OFF
ROCM_PATH: /opt/rocm
USE_DNNL: OFF
USE_MSCCL: OFF
USE_VITIS_AI: OFF
USE_MLIR: OFF
USE_RCCL: OFF
USE_LLVM: ON
USE_VERILATOR: OFF
USE_TF_TVMDSOOP: OFF
USE_THREADS: ON
USE_MSVC_MT: OFF
BACKTRACE_ON_SEGFAULT: OFF
USE_GRAPH_EXECUTOR: ON
USE_NCCL: OFF
USE_ROCBLAS: OFF
GIT_COMMIT_HASH: ff884b609a2eb94fef1f061bff0ec867b79d4ba0
USE_VULKAN: OFF
USE_RUST_EXT: OFF
USE_CUTLASS: OFF
USE_CPP_RPC: OFF
USE_HEXAGON: OFF
USE_CUSTOM_LOGGING: OFF
USE_UMA: OFF
USE_FALLBACK_STL_MAP: OFF
USE_SORT: ON
USE_RTTI: ON
GIT_COMMIT_TIME: 2024-09-06 12:28:28 -0400
USE_HIPBLAS: OFF
USE_HEXAGON_SDK: /path/to/sdk
USE_BLAS: none
USE_ETHOSN: OFF
USE_LIBTORCH: OFF
USE_RANDOM: ON
USE_CUDA: OFF
USE_COREML: OFF
USE_AMX: OFF
BUILD_STATIC_RUNTIME: OFF
USE_CMSISNN: OFF
USE_KHRONOS_SPIRV: OFF
USE_CLML_GRAPH_EXECUTOR: OFF
USE_TFLITE: OFF
USE_HEXAGON_GTEST: /path/to/hexagon/gtest
PICOJSON_PATH: 3rdparty/picojson
USE_OPENCL_ENABLE_HOST_PTR: OFF
INSTALL_DEV: OFF
USE_PROFILER: ON
USE_NNPACK: OFF
LLVM_VERSION: 18.1.8
USE_MRVL: OFF
USE_OPENCL: OFF
COMPILER_RT_PATH: 3rdparty/compiler-rt
RANG_PATH: 3rdparty/rang/include
USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
USE_OPENMP: none
USE_BNNS: OFF
USE_FLASHINFER: OFF
USE_CUBLAS: OFF
USE_METAL: OFF
USE_MICRO_STANDALONE_RUNTIME: OFF
USE_HEXAGON_EXTERNAL_LIBS: OFF
USE_ALTERNATIVE_LINKER: AUTO
USE_BYODT_POSIT: OFF
USE_NVSHMEM: OFF
USE_HEXAGON_RPC: OFF
USE_MICRO: OFF
DMLC_PATH: 3rdparty/dmlc-core/include
INDEX_DEFAULT_I64: ON
USE_RELAY_DEBUG: OFF
USE_RPC: ON
USE_TENSORFLOW_PATH: none
TVM_CLML_VERSION: 3
USE_MIOPEN: OFF
USE_ROCM: OFF
USE_PAPI: OFF
USE_CURAND: OFF
TVM_CXX_COMPILER_PATH: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++
HIDE_PRIVATE_SYMBOLS: OFF

How do I make my CNN, which is tuned via the auto-scheduler (Ansor), use CLML?

Complete code to reproduce is below; make sure to change the device key (it is currently set to "android"):

import os
import sys

import numpy as np

import tvm
from tvm import relay, auto_scheduler
import tvm.relay.testing
from tvm.contrib import graph_executor

def get_network(name, batch_size, layout="NHWC", dtype="float32"):
    """Get the symbol definition and random weight of a network.

    Parameters
    ----------
    name : str
        Network to build: "resnet-<N>", "resnet3d-<N>", "mobilenet",
        "squeezenet_v1.1" or "inception_v3".
    batch_size : int
        Leading dimension of the input and output shapes.
    layout : str
        "NHWC" or "NCHW".
    dtype : str
        Data type forwarded to the ``relay.testing`` workload builders.

    Returns
    -------
    tuple
        ``(mod, params, input_shape, output_shape)``.

    Raises
    ------
    ValueError
        If ``layout`` is not "NHWC"/"NCHW", or if ``name`` is not one of the
        supported networks.
    """

    # auto-scheduler prefers NHWC layout
    if layout == "NHWC":
        image_shape = (224, 224, 3)
    elif layout == "NCHW":
        image_shape = (3, 224, 224)
    else:
        raise ValueError("Invalid layout: " + layout)

    input_shape = (batch_size,) + image_shape
    output_shape = (batch_size, 1000)

    if name.startswith("resnet-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name.startswith("resnet3d-"):
        # NOTE(review): this branch calls the same 2-D resnet builder with a
        # 2-D image_shape as the "resnet-" branch above; confirm whether a
        # dedicated 3-D workload was intended.
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(
            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
        )
    elif name == "squeezenet_v1.1":
        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
        mod, params = relay.testing.squeezenet.get_workload(
            version="1.1",
            batch_size=batch_size,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "inception_v3":
        # Inception v3 uses 299x299 inputs, so the default shape is overridden.
        # NOTE(review): `layout` is not forwarded to get_workload here, so the
        # reported NHWC input_shape may not match the built module; confirm.
        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    else:
        # Without this branch an unknown name fell through to the return
        # statement and failed with an UnboundLocalError on `mod`.
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape


# Define the neural network and compilation target.
# Everything below runs at module level; tune_and_evaluate() reads these
# names (tasks, task_weights, use_ndk, device_key, log_file, mod, target,
# params, input_shape, dtype) as globals.
network = "mobilenet"
batch_size = 1
layout = "NHWC"
# Set this to True if you use ndk tools for cross compiling.
# NOTE: this flag only selects the LocalBuilder build_func used during tuning;
# the final export in tune_and_evaluate() always uses ndk.create_shared.
use_ndk = False
# Path to cross compiler (machine-specific; adjust to the local NDK install)
os.environ["TVM_NDK_CC"] = "/Users/varunnaw/Library/Android/sdk/ndk/26.1.10909125/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android28-clang"
# OpenCL device code for Adreno, with an AArch64 Android host.
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-android")
dtype = "float16"
# Tuning records are written to, and later read back from, this file.
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, target.kind.name)

# Replace this with the device key in your tracker
device_key = "android"

# Query the remote device (through the RPC tracker at 127.0.0.1:9190) for the
# limits that are passed to the auto-scheduler as hardware parameters.
from tvm.auto_scheduler.utils import request_remote
remote = request_remote(device_key, "127.0.0.1", 9190)
dev = remote.cl()
max_shared_memory_per_block = dev.max_shared_memory_per_block
# There is no explicit local memory limitation
# so we can use INT32_MAX to disable the check on local_memory.
max_local_memory_per_block = 2147483647 # INT32_MAX
max_threads_per_block = dev.max_threads_per_block
# Use a quarter of the warp size when that is larger than 1, otherwise the
# full warp size.
max_vthread_extent = int(dev.warp_size / 4) if int(dev.warp_size / 4) > 1 else dev.warp_size
warp_size = dev.warp_size
# NOTE(review): the three leading positional values (-1, 16, 64) are presumably
# num_cores, vector_unit_bytes and cache_line_bytes; confirm against the
# HardwareParams signature of the TVM version in use.
hardware_params = auto_scheduler.HardwareParams(-1, 16, 64,
                                                max_shared_memory_per_block, max_local_memory_per_block,
                                                max_threads_per_block, max_vthread_extent, warp_size)

# Extract tasks from the network
print("Extract tasks...")
mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
# tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)


# Same extraction as the commented-out call above, but with the device-derived
# hardware parameters supplied explicitly.
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target, hardware_params = hardware_params)

# Print every extracted task together with its compute DAG for inspection.
for idx, task in enumerate(tasks):
    print("========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key))
    print(task.compute_dag)

def tune_and_evaluate():
    """Tune the extracted tasks, build the network from the log, and benchmark it.

    Reads the module-level ``tasks``, ``task_weights``, ``use_ndk``,
    ``device_key``, ``log_file``, ``mod``, ``target``, ``params``,
    ``input_shape`` and ``dtype``. Progress and the benchmark result are
    printed; nothing is returned.
    """
    print("Begin tuning...")
    scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)

    # The local builder cross-compiles with the NDK only when requested.
    if use_ndk:
        build_func = "ndk"
    else:
        build_func = "default"
    local_builder = auto_scheduler.LocalBuilder(build_func=build_func)
    rpc_runner = auto_scheduler.RPCRunner(
        device_key, host="127.0.0.1", port=9190, repeat=1, timeout=50
    )
    options = auto_scheduler.TuningOptions(
        num_measure_trials=30,  # change this to 20000 to achieve the best performance
        builder=local_builder,
        runner=rpc_runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    scheduler.tune(options)

    # Build the full network using the best schedules recorded in the log.
    print("Compile...")
    with auto_scheduler.ApplyHistoryBest(log_file), tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target, params=params)

    # Open a fresh RPC session for deployment.
    print("=============== Request Remote ===============")
    from tvm.auto_scheduler.utils import request_remote

    session = request_remote(device_key, "127.0.0.1", 9190)
    device = session.cl()
    from tvm.contrib import utils, ndk

    # Export with the NDK toolchain, upload, and load the library remotely.
    lib_name = "deploy_lib.so"
    workdir = utils.tempdir()
    local_lib_path = workdir.relpath(lib_name)
    lib.export_library(local_lib_path, fcompile=ndk.create_shared)
    session.upload(local_lib_path)
    remote_lib = session.load_module(lib_name)
    executor = graph_executor.GraphModule(remote_lib["default"](device))

    # Feed one random input of the network's input shape.
    sample = np.random.uniform(size=input_shape).astype(dtype)
    executor.set_input("data", tvm.nd.array(sample))

    # Measure inference time on the remote device.
    print("Evaluate inference time cost...")
    print(executor.benchmark(device, repeat=3, min_repeat_ms=500))

# Script entry point: runs tuning, compilation and the remote benchmark as
# soon as this file is executed.
tune_and_evaluate()

Where do I put the following call?

mod = clml.partition_for_clml(mod, params)

Where do I put this call, which is mentioned in this Qualcomm post by Siva?

Hi Varun

Enabling CLML requires the following CMake settings: set(USE_CLML "${ADRENO_OPENCL}") and set(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}").

The CLML partitioning call should be placed just before relay.build. It separates the CLML-supported calls out to a different runtime. The leftover operations will continue to use TVM's native implementation (untuned, AutoTVM-tuned, auto-scheduler-tuned, etc.).

Hope this helps

Thanks, Siva