Why does meta schedule not find any schedules?

import tvm
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T
import tempfile
import tvm.meta_schedule as ms
from tvm.ir import IRModule

@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")

        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    T.init()
                    out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    T.init()
                    out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
        for j in T.serial(3):
            for i in T.serial(3):
                with T.block("db3"):
                    T.init()
                    out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
        for j in T.serial(2):
            for i in T.serial(1):
                with T.block("db5"):
                    T.init()
                    out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
        for j in T.serial(3):
            for i in T.serial(1):
                with T.block("db8"):
                    T.init()
                    out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")):
        cls = Module
        out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out

if __name__ == "__main__":
    ...
    target = tvm.target.Target("llvm -num-cores 1")

    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.tir_integration.tune_tir(
            mod=Module,
            target=target,
            work_dir=work_dir,
            max_trials_global=64,
        )

        workload = database.commit_workload(IRModule({"dense_loop": Module["dense_loop"]}))
        top_k = database.get_top_k(workload, 1)
        if not top_k:
            raise RuntimeError("No schedules found in the database.")
        print("Best schedule found:")
        print(top_k[0])

I always hit the “No schedules found in the database” code path. Why is that the case? Am I using MetaSchedule incorrectly?

Two things are likely going wrong here. First, ms.tir_integration.tune_tir expects a TIR PrimFunc (or a TIR-only IRModule), so passing the whole module with the Relax main in it is outside its contract. Second, the database lookup never matches: tune_tir normalizes whatever it tunes into a fresh IRModule whose single function is named "main" (which is why, in the code below, the optimized function lives in sch.mod["main"]), so the workload you commit for IRModule({"dense_loop": Module["dense_loop"]}) has a different structural hash and get_top_k finds nothing.

Instead, you can first optimize the TIR primitive function independently using ms.tir_integration.tune_tir, retrieve the best schedule with ms.tir_integration.compile_tir, and then merge the optimized version back into the high-level Relax module.

import os
import tvm
import numpy as np
import tvm.meta_schedule as ms

from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I


@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")

        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    T.init()
                    out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    T.init()
                    out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
        for j in T.serial(3):
            for i in T.serial(3):
                with T.block("db3"):
                    T.init()
                    out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
        for j in T.serial(2):
            for i in T.serial(1):
                with T.block("db5"):
                    T.init()
                    out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
        for j in T.serial(3):
            for i in T.serial(1):
                with T.block("db8"):
                    T.init()
                    out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")):
        cls = Module
        out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out


mod = Module
# The TIR function that will be tuned
dense_loop_tir = mod["dense_loop"]

target = tvm.target.Target("llvm -num-cores=1")
this_dir = os.path.dirname(os.path.abspath(__file__))
work_dir = os.path.join(this_dir, "tuning_logs")

# Tune the TIR function
database = ms.tir_integration.tune_tir(
    mod=dense_loop_tir,
    target=target,
    work_dir=work_dir,
    max_trials_global=64,
    num_trials_per_iter=16,
)
if database is None:
    raise ValueError("Database is None!")

# Query the tuned database for the best record and apply it, yielding a tir.Schedule
sch = ms.tir_integration.compile_tir(database, dense_loop_tir, target)
if sch is None:
    print("No valid schedule found!")
else:
    sch.mod.show()

    # Put the optimized TIR PrimFunc back into the original module.
    # In `sch.mod`, the optimized PrimFunc is stored as `sch.mod["main"]`.
    optimized_tir = sch.mod["main"]
    new_mod = tvm.IRModule({"dense_loop": optimized_tir, "main": mod["main"]})
    new_mod.show()

    # Build new module
    new_mod = relax.transform.LegalizeOps()(new_mod)
    ex = relax.build(new_mod, target=target)
    vm = relax.VirtualMachine(ex, tvm.cpu())

    # Prepare Data
    val_np = np.random.rand(37).astype("float64")
    vec_np = np.random.rand(11).astype("float64")
    val_tvm = tvm.nd.array(val_np, device=tvm.cpu())
    vec_tvm = tvm.nd.array(vec_np, device=tvm.cpu())

    # Execute
    output_tvm = vm["main"](val_tvm, vec_tvm)

    # Output
    output_np = output_tvm.numpy()
    print("Output shape:", output_np.shape)  # out_sinfo: (11,)
    print("Output values:", output_np)

The output:

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(val: T.Buffer((37,), "float64"), vec: T.Buffer((11,), "float64"), out: T.Buffer((11,), "float64")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db0"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 4)
                    T.reads(val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db1"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[T.Add(0, 5)])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[T.Add(0, 5)]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db3"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 9)
                    T.reads(val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] + val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db5"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db8"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 3)
                    T.reads(val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1 + 6])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1 + 6]
                    T.evaluate(0)

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")) -> R.Tensor((11,), dtype="float64"):
        v = T.int64()
        k = T.int64()
        out = R.call_tir(dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out

Output shape: (11,)
Output values: [1.56512320e-001 2.21357108e-001 1.09694842e+000 7.23903158e-001
 1.02699930e+000 1.66537513e+000 0.00000000e+000 4.64222119e-310
 4.24399158e-314 0.00000000e+000 0.00000000e+000]
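
Note that dense_loop never writes out[6:] (its five loop nests only touch indices 0 through 5), so those positions hold whatever uninitialized memory the output buffer happened to contain, which is why the tail of the printed array is near-denormal garbage. As a sanity check on the entries that are written, here is a NumPy reference mirroring the five loop nests; it assumes the output buffer started out at (effectively) zero, as it apparently did in this run, since dense_loop only ever accumulates into out:

import numpy as np

# NumPy reference for the five loop nests in dense_loop.
# val_np / vec_np / output_np are the arrays from the script above.
expected = np.zeros(11)
for j in range(2):            # block "db0"
    for i in range(2):
        expected[i] += val_np[j * 2 + i] * vec_np[j]
for i in range(2):            # block "db1" (j is always 0)
    expected[i] += val_np[4 + i] * vec_np[5]
for j in range(3):            # block "db3"
    for i in range(3):
        expected[i + 2] += val_np[6 + j * 3 + i] * vec_np[j + 2]
for j in range(2):            # block "db5" (i is always 0)
    expected[5] += val_np[15 + j] * vec_np[j]
for j in range(3):            # block "db8" (i is always 0)
    expected[5] += val_np[21 + j] * vec_np[j + 6]

# Indices 6..10 are never written by dense_loop.
np.testing.assert_allclose(output_np[:6], expected[:6])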

As another optimization strategy, the Relax module can be tuned end to end; a minimal sketch of that approach follows.
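
The sketch below assumes the ms.relax_integration.tune_relax / compile_relax API available in recent TVM releases; exact signatures may differ across versions.

import tempfile

import numpy as np
import tvm
import tvm.meta_schedule as ms
from tvm import relax

target = tvm.target.Target("llvm -num-cores=1")

with tempfile.TemporaryDirectory() as work_dir:
    # Tune every TIR function reachable from the Relax module;
    # all tuning records land in a single database.
    database = ms.relax_integration.tune_relax(
        mod=Module,
        params={},
        target=target,
        work_dir=work_dir,
        max_trials_global=64,
    )
    # Build the module with the tuned database applied.
    ex = ms.relax_integration.compile_relax(database, mod=Module, target=target, params={})

vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](
    tvm.nd.array(np.random.rand(37)),  # np.random.rand gives float64, matching the buffers
    tvm.nd.array(np.random.rand(11)),
)
print(out.numpy())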
