import tvm
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T
import tempfile
import tvm.meta_schedule as ms
from tvm.ir import IRModule


@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")
        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    with T.init():
                        out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    with T.init():
                        out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
        for j in T.serial(3):
            for i in T.serial(3):
                with T.block("db3"):
                    with T.init():
                        out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
        for j in T.serial(2):
            for i in T.serial(1):
                with T.block("db5"):
                    with T.init():
                        out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
        for j in T.serial(3):
            for i in T.serial(1):
                with T.block("db8"):
                    with T.init():
                        out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")):
        cls = Module
        out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out


if __name__ == "__main__":
    target = tvm.target.Target("llvm -num-cores 1")
    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.tir_integration.tune_tir(
            mod=Module,
            target=target,
            work_dir=work_dir,
            max_trials_global=64,
        )
        workload = database.commit_workload(IRModule({"dense_loop": Module["dense_loop"]}))
        top_k = database.get_top_k(workload, 1)
        if not top_k:
            raise RuntimeError("No schedules found in the database.")
        print("Best schedule found:")
        print(top_k[0])
I always hit the “No schedules found in the database” code path. Why is that the case? Am I using meta schedule incorrectly?
Your lookup comes up empty because commit_workload matches workloads by structural equality of the committed module: the fresh IRModule({"dense_loop": ...}) you build afterwards is not the module that tune_tir recorded during tuning, so a brand-new workload with zero tuning records is created and get_top_k returns an empty list. Instead, you can first optimize the TIR primitive function on its own using ms.tir_integration.tune_tir, retrieve the best schedule with ms.tir_integration.compile_tir (which performs the matching database lookup for you), and then merge the optimized version back into the high-level Relax module.
import os
import tvm
import numpy as np
import tvm.meta_schedule as ms
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I


@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")
        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    with T.init():
                        out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    with T.init():
                        out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
        for j in T.serial(3):
            for i in T.serial(3):
                with T.block("db3"):
                    with T.init():
                        out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
        for j in T.serial(2):
            for i in T.serial(1):
                with T.block("db5"):
                    with T.init():
                        out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
        for j in T.serial(3):
            for i in T.serial(1):
                with T.block("db8"):
                    with T.init():
                        out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")):
        cls = Module
        out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out


mod = Module
# The TIR function that will be tuned
dense_loop_tir = mod["dense_loop"]

target = tvm.target.Target("llvm -num-cores=1")
this_dir = os.path.dirname(os.path.abspath(__file__))
work_dir = os.path.join(this_dir, "tuning_logs")

# Tune the TIR function
database = ms.tir_integration.tune_tir(
    mod=dense_loop_tir,
    target=target,
    work_dir=work_dir,
    max_trials_global=64,
    num_trials_per_iter=16,
)
if database is None:
    raise ValueError("Database is None!")

# Compile the TIR function with the tuned database to a tir.Schedule
sch = ms.tir_integration.compile_tir(database, dense_loop_tir, target)
if sch is None:
    print("No valid schedule found!")
else:
    sch.mod.show()

# Replace the optimized TIR Prim func back into the original module.
# In `sch.mod`, the optimized TIR Prim func is stored as `sch.mod["main"]`.
optimized_tir = sch.mod["main"]
new_mod = tvm.IRModule({"dense_loop": optimized_tir, "main": mod["main"]})
new_mod.show()

# Build new module
new_mod = relax.transform.LegalizeOps()(new_mod)
ex = relax.build(new_mod, target=target)
vm = relax.VirtualMachine(ex, tvm.cpu())

# Prepare Data
val_np = np.random.rand(37).astype("float64")
vec_np = np.random.rand(11).astype("float64")
val_tvm = tvm.nd.array(val_np, device=tvm.cpu())
vec_tvm = tvm.nd.array(vec_np, device=tvm.cpu())

# Execute
output_tvm = vm["main"](val_tvm, vec_tvm)

# Output
output_np = output_tvm.numpy()
print("Output shape:", output_np.shape)  # out_sinfo: (11,)
print("Output values:", output_np)
The output:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(val: T.Buffer((37,), "float64"), vec: T.Buffer((11,), "float64"), out: T.Buffer((11,), "float64")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db0"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 4)
                    T.reads(val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db1"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[T.Add(0, 5)])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[T.Add(0, 5)]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db3"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 9)
                    T.reads(val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] + val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db5"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db8"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 3)
                    T.reads(val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1 + 6])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1 + 6]
                    T.evaluate(0)

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")) -> R.Tensor((11,), dtype="float64"):
        v = T.int64()
        k = T.int64()
        out = R.call_tir(dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out
Output shape: (11,)
Output values: [1.56512320e-001 2.21357108e-001 1.09694842e+000 7.23903158e-001
1.02699930e+000 1.66537513e+000 0.00000000e+000 4.64222119e-310
4.24399158e-314 0.00000000e+000 0.00000000e+000]
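Alternatively, you can skip stitching the tuned PrimFunc back by hand and let MetaSchedule tune the whole Relax module in one go. Below is only a minimal sketch of that flow; it assumes your TVM build ships ms.relax_integration.tune_relax and relax.transform.MetaScheduleApplyDatabase, that task extraction copes with the symbolic input shapes of main, and the work_dir name is just a placeholder.

import tvm
import tvm.meta_schedule as ms
from tvm import relax

target = tvm.target.Target("llvm -num-cores=1")

# Extract every TIR task reachable from the Relax module and tune them together.
database = ms.relax_integration.tune_relax(
    mod=Module,            # the Relax + TIR module defined above
    params={},             # no constant parameters to bind
    target=target,
    work_dir="tuning_logs_e2e",  # placeholder directory
    max_trials_global=64,
)

# Rewrite the call_tir targets with the best records from the database, then build.
with target, database, tvm.transform.PassContext(opt_level=3):
    tuned_mod = relax.transform.MetaScheduleApplyDatabase()(Module)
ex = relax.build(tuned_mod, target=target)
vm = relax.VirtualMachine(ex, tvm.cpu())

The practical difference from the per-PrimFunc route is that the database is applied to all call_tir targets at once, so there is no manual IRModule reconstruction step.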
This end-to-end optimization methodology is demonstrated in more detail here: