In TVM, do we have to manually specify a schedule?

Hi, I was following this TVM tutorial https://github.com/dlsyscourse/public_notebooks/blob/main/24_machine_learning_compilation_deployment_implementation.ipynb, and was wondering if there’s some “automatic” or “default” schedule to optimize the program, like this:

A = te.placeholder(shape=(128,), dtype="float32", name="A")
B = te.placeholder(shape=(128,), dtype="float32", name="B")
C = te.compute((128,), lambda i: A[i] + B[i], name="C")
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.optimize()  # does this exist?

So that I only need to specify the computation itself, and don’t need to know anything about the schedule at all. I was thinking of something conceptually equivalent to this:

@tvm.compile
def vec_add(a, b):
    return a + b

And somehow TVM will optimize the function without the user specifying any schedule. Thanks!

Yes there is, and that’s what we call auto-scheduling in TVM. Please check 4. Automatic Program Optimization — Machine Learning Compilation 0.0.1 documentation for how to use TVM’s latest meta schedule to tune programs.

For example, you can use default auto-scheduling with your code in the following way:

import tvm
import numpy as np
from tvm import te
from tvm import meta_schedule as ms

A = te.placeholder(shape=(128,), dtype="float32", name="A")
B = te.placeholder(shape=(128,), dtype="float32", name="B")
C = te.compute((128,), lambda i: A[i] + B[i], name="C")
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
mod = tvm.IRModule.from_expr(func)

# auto-scheduling: search for a good schedule and record the results in a database
database = ms.tune_tir(
    mod=mod,
    target="llvm --num-cores=1",
    max_trials_global=64,    # total number of candidate schedules to try
    num_trials_per_iter=64,  # candidates measured per search iteration
    work_dir="./tune_tmp",   # where tuning records are written
    task_name="main"
)
# look up the best record in the database and apply it to the module
sch = ms.tir_integration.compile_tir(database, mod, "llvm --num-cores=1")

# build and evaluate
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator(lib.entry_name, tvm.cpu(0))
a_nd = tvm.nd.array(np.random.randn(128).astype("float32"))
b_nd = tvm.nd.array(np.random.randn(128).astype("float32"))
c_nd = tvm.nd.empty((128,), dtype="float32")
print("Time cost of function after tuning: {:.3f} ms".format(f_timer_after(a_nd, b_nd, c_nd).mean * 1000))

Thank you very much for pointing me to the auto scheduling page and the code example, it’s incredibly helpful!!

Two more questions:

  • If the compile target is CUDA, can I use the same auto-scheduling code and just replace "llvm --num-cores=1" with "cuda"?
  • Does tvm.nd.array(np.random.randn(128).astype("float32")) copy the underlying NumPy array’s data?

Thanks!

Actually I’ve found an example from 6.1. Part 1 — Machine Learning Compilation 0.0.1 documentation, but got an error with the following code:

database = ms.tune_tir(
    mod=ir_module,
    target="nvidia/tesla-p100",
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_tmp",
    task_name="main"
)
sch = ms.tir_integration.compile_tir(database, ir_module, "nvidia/tesla-p100")
rt_mod = tvm.build(sch.mod, target="nvidia/tesla-p100")
dev = tvm.cuda(0)
evaluator = rt_mod.time_evaluator("main", dev, number=10)
num_flop = 2 * 1024 * 1024 * 1024  # a 1024^3 matmul does 2 * 1024^3 floating-point ops
A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
print("MetaSchedule: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))

The tvm.build(sch.mod, target="nvidia/tesla-p100") line is throwing an error:

  Did you forget to bind?
    Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `C` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/home/tzhou80/projects/tvm/src/tir/analysis/verify_memory.cc", line 214
RuntimeError: Memory verification failed with the following errors:
PrimFunc([var_A, var_B, var_C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main", "target": cuda -keys=cuda,gpu -arch=sm_60 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32} {
  parallel (i0_fused, 0, 128) {
    C[i0_fused] = (A[i0_fused] + B[i0_fused])
  }
}

It looks like tvm.build is somehow accessing some data that’s not on the device?

I am also having the same issue when running on cuda.

@yzh119
Do you have any ideas?

@twmht Hi, if you want to run on CUDA, you need to do tune_tir with the target being "cuda", and also do compile_tir with the target being "cuda" (or substitute "cuda" with a specific target tag from the list here: https://github.com/apache/tvm/blob/main/src/target/tag.cc#L126-L378, e.g., target="nvidia/geforce-rtx-3090-ti").

The code snippet you shared is exclusive to LLVM and is not applicable to CUDA. Hope this can help.
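For reference, a minimal sketch of that CUDA path, assuming the same ir_module as above (a 1024x1024 matmul PrimFunc named "main") and a local GPU; the 3090 Ti tag is just an example:

import tvm
import numpy as np
from tvm import meta_schedule as ms

target = "cuda"  # or a specific tag, e.g. "nvidia/geforce-rtx-3090-ti"
database = ms.tune_tir(
    mod=ir_module,
    target=target,            # tune for CUDA...
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_tmp",
    task_name="main"
)
# ...and compile with the same target
sch = ms.tir_integration.compile_tir(database, ir_module, target)
rt_mod = tvm.build(sch.mod, target=target)

# allocate the arrays on the GPU before calling the built function
dev = tvm.cuda(0)
a_nd = tvm.nd.array(np.random.uniform(size=(1024, 1024)).astype("float32"), dev)
b_nd = tvm.nd.array(np.random.uniform(size=(1024, 1024)).astype("float32"), dev)
c_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
evaluator = rt_mod.time_evaluator("main", dev, number=10)
print("mean time: %f ms" % (evaluator(a_nd, b_nd, c_nd).mean * 1000))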


Hi, I am also encountering this error message. Did you happen to solve it? Thank you.

Did you forget to bind?
Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `C` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `C` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `C` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
File "/workspace/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
    T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    for i, j, k in T.grid(1024, 1024, 1024):
        cse_var_2: T.int32 = i * 1024
        cse_var_1: T.int32 = cse_var_2 + j
        C_1 = T.Buffer((1048576,), data=C.data)
        if k == 0:
            C_1[cse_var_1] = T.float32(0)
        A_1 = T.Buffer((1048576,), data=A.data)
        B_1 = T.Buffer((1048576,), data=B.data)
        C_1[cse_var_1] = C_1[cse_var_1] + A_1[cse_var_2 + k] * B_1[j * 1024 + k]

@bnovike I’m getting the same error – did you find a solution?

Hi, in my case, downgrading the xgboost package to version 1.4.2 helped.
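For reference, assuming a pip-managed environment, this is the downgrade plus a quick version check:

# run in the shell (assumption: pip-managed install):
#   pip install "xgboost==1.4.2"
import xgboost
print(xgboost.__version__)  # should print 1.4.2 after the downgrade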

But I’m not even using xgboost: Memory verification failed with Relax

The error message I encountered happens when doing meta schedule tuning and compiling. Meta schedule tuning needs xgboost to run, so when I traced the code, the xgboost version turned out to be the main problem. I am not sure what the main problem is in your case.