I’m trying to auto-schedule a simple operator graph as follows (code below adapted from the auto-scheduling tutorial).
import numpy as np
import tvm
from tvm import te, auto_scheduler
@auto_scheduler.register_workload
def gemm_fused_layer(M, K, N):
    w1 = te.placeholder((M, K), name="w1")
    w2 = te.placeholder((M, K), name="w2")
    m1 = te.placeholder((N, K), name="m1")
    m2 = te.placeholder((N, K), name="m2")
    # Two independent GEMMs (m1/m2 are indexed as [j, k], i.e. transposed)
    k1 = te.reduce_axis((0, K), name="k1")
    mm1 = te.compute((M, N), lambda i, j: te.sum(w1[i, k1] * m1[j, k1], axis=k1), name="mm1")
    k2 = te.reduce_axis((0, K), name="k2")
    mm2 = te.compute((M, N), lambda i, j: te.sum(w2[i, k2] * m2[j, k2], axis=k2), name="mm2")
    # Combine the two results with an elementwise multiply
    o = te.compute((M, N), lambda i, j: mm1[i, j] * mm2[i, j], name="o")
    # mm1 and mm2 are in the returned tensor list, so the built function
    # will take device buffers for them as well
    return [w1, w2, m1, m2, mm1, mm2, o]
######################################################################
# Create the search task
target = tvm.target.Target("cuda")
M, K, N = 512, 512, 512
task = auto_scheduler.SearchTask(func=gemm_fused_layer, args=(M, K, N), target=target)
# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
log_file = "gemm_fused.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
task.tune(tune_option)
sch, args = task.apply_best(log_file)
# Kill the measurement process
del measure_ctx
func = tvm.build(sch, args, target)
# print(func.imported_modules[0].get_source())
# Evaluate execution time
device = tvm.gpu()
w1_np = np.random.uniform(size=(M, K)).astype(np.float32)
w2_np = np.random.uniform(size=(M, K)).astype(np.float32)
m1_np = np.random.uniform(size=(N, K)).astype(np.float32)
m2_np = np.random.uniform(size=(N, K)).astype(np.float32)
# Copy the input data to the device
w1_tvm = tvm.nd.array(w1_np, device=device)
w2_tvm = tvm.nd.array(w2_np, device=device)
m1_tvm = tvm.nd.array(m1_np, device=device)
m2_tvm = tvm.nd.array(m2_np, device=device)
# The workload returns mm1 and mm2 too, so allocate buffers for them as well
mm1_tvm = tvm.nd.empty((M, N), device=device)
mm2_tvm = tvm.nd.empty((M, N), device=device)
o_tvm = tvm.nd.empty((M, N), device=device)
evaluator = func.time_evaluator(func.entry_name, device, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(w1_tvm, w2_tvm, m1_tvm, m2_tvm, mm1_tvm, mm2_tvm, o_tvm).results) * 1000)
)
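For completeness, once tuning succeeds I also plan to check the result against a numpy reference, roughly like this (a sketch; it assumes the built function takes all seven tensors from the workload's return list, in order):

# mm1 = w1 @ m1.T and mm2 = w2 @ m2.T (m1/m2 are indexed as [j, k]),
# so the expected output is their elementwise product
func(w1_tvm, w2_tvm, m1_tvm, m2_tvm, mm1_tvm, mm2_tvm, o_tvm)
o_ref = (w1_np @ m1_np.T) * (w2_np @ m2_np.T)
np.testing.assert_allclose(o_tvm.numpy(), o_ref, rtol=1e-3)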
As it stands, however, task.tune throws the following assertion failure:
Traceback (most recent call last):
File "tune_gemm_fused_cuda.py", line 47, in <module>
task.tune(tune_option)
File "/home/ppf/rnn_compilers/lowering/tvm/tvm/python/tvm/auto_scheduler/search_task.py", line 452, in tune
_ffi_api.AutoSchedule(search_policy, tuning_options)
File "/home/ppf/rnn_compilers/lowering/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
7: TVMFuncCall
6: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
5: tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)
4: tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)
3: tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)
2: tvm::auto_scheduler::SketchPolicyNode::GenerateSketches()
1: tvm::auto_scheduler::RuleMultiLevelTilingWithFusion::Apply(tvm::auto_scheduler::SketchPolicyNode const&, tvm::auto_scheduler::State const&, int) const
0: tvm::auto_scheduler::FollowTiling(tvm::auto_scheduler::State const&, int, std::vector<int, std::allocator<int> > const&, int)
File "../src/auto_scheduler/search_policy/utils.cc", line 289
TVMError:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
Check failed: state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt == split_step_ids.size() (8 vs. 2) :
When I simplify the operator graph and remove the final elementwise multiplication (the `o` stage), auto-scheduling works fine. Is there anything I’m doing wrong?
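For reference, the simplified variant I mean is essentially just the GEMM stage, roughly like this (a sketch; the name gemm_single is mine):

@auto_scheduler.register_workload
def gemm_single(M, K, N):
    w1 = te.placeholder((M, K), name="w1")
    m1 = te.placeholder((N, K), name="m1")
    k = te.reduce_axis((0, K), name="k")
    mm1 = te.compute((M, N), lambda i, j: te.sum(w1[i, k] * m1[j, k], axis=k), name="mm1")
    return [w1, m1, mm1]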