Sketch rule for sparse matrix multiplication [troubleshooting]

Hi, I am trying to run auto-scheduling for sparse-dense matrix multiplication, but I ran into the error shown below.

In general, I follow the tutorial: I supply the sparse matrix data for measurement during tuning. I represent the sparse matrix in CSR format (not the BSR format used in the tutorial) and define the compute rule accordingly. Unlike the tutorial, I do not specify a custom sketch rule. As far as I can tell, that is the only significant difference between my implementation and the tutorial, so I suspect it may be the source of the error.
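For readers less familiar with CSR, here is a minimal pure-Python sketch of the three arrays that play the role of `W_data`, `W_indices`, and `W_indptr` in this post. The helper name `dense_to_csr` and the small example matrix are illustrative assumptions only; they are not part of TVM or SciPy.

```python
def dense_to_csr(rows):
    """Convert a dense row-major matrix (list of lists) into CSR arrays."""
    data, indices, indptr = [], [], [0]
    for row in rows:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)    # nonzero value
                indices.append(col)   # column index of that value
        indptr.append(len(data))      # running count marks where the next row starts
    return data, indices, indptr


# Hypothetical 3x3 example matrix with four nonzeros.
W = [
    [5.0, 0.0, 0.0],
    [0.0, 0.0, 2.0],
    [0.0, 3.0, 4.0],
]
data, indices, indptr = dense_to_csr(W)
print(data, indices, indptr)
```

Row `r` owns the slice `data[indptr[r]:indptr[r + 1]]`, so `indptr` always has one more entry than the matrix has rows, which is consistent with the `[987]` placeholder reported for a 986-row matrix.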

So, here are my questions:

(1) How should I solve this error? Is it related to the missing custom sketch rule? If yes, is the tutorial you mentioned the only learning material on custom sketch rules? If no, what are the possible reasons?

(2) Given that a custom sketch rule is hard to write well and sparse matrix multiplication is an important, common operator, is it possible to auto-schedule sparse matrix multiplication without a custom sketch rule? That is, the only differences between sparse-dense matrix multiplication and dense GEMM would be the compute definition and the data format supplied, and the overall flow would be similar to this tutorial.

(3) TOPI is used in the tutorial, so I also used it in the compute rule definition. However, based on the answer to this Q&A, I think TOPI serves as a template for AutoTVM to tune the schedule, rather than being used in auto-scheduling. So, is it right to think that TOPI aims to reduce the effort of writing low-level tensor expressions, but we still need to set up AutoTVM or the auto-scheduler to start the search process?

I attached my code and error message below.

Here is my code:

import tvm
import os
import numpy as np
import tvm.testing
from tvm import te, auto_scheduler, runtime, topi
from tvm.auto_scheduler import _ffi_api
from tvm.topi.utils import get_const_tuple
from tvm.topi.sparse.utils import random_bsr_matrix
import scipy.sparse as sp

# define the compute rule
@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def sparse_dense_gemm(M, K, N, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
    X = te.placeholder((M, K), name="X", dtype=dtype)
    W_data = te.placeholder(shape=w_data_shape, dtype=dtype)
    W_indices = te.placeholder(shape=w_indices_shape, dtype="int32")
    W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32")
    out = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
    return [X, W_data, W_indices, W_indptr, out]

M = 986
N = 986
K = 986
sparsity = 0.8  # NOTE: passed as `density` to sp.random below, so roughly 80% of W's entries are nonzero
target = tvm.target.Target("llvm -mcpu=skylake-avx512")
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = sp.random(N, K, density=sparsity, format="csr", dtype="float32")
W_np = W_sp_np.todense()
Y_np = X_np @ W_np.T    # Process the matrix multiplication

prefix = "sparse_dense_csr_%d_%d_%d_%d_" % (
    N,
    K,
    W_sp_np.indices.shape[0],
    W_sp_np.indptr.shape[0],
)
task = tvm.auto_scheduler.SearchTask(
    func=sparse_dense_gemm,
    args=(M, K, N, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"),
    target=target,
    task_inputs={
        prefix + "W_data": runtime.ndarray.array(W_sp_np.data),
        prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices),
        prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr),
    },
    task_inputs_save_to_file=True,
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
log_file="test.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=0,
    runner=auto_scheduler.LocalRunner(timeout=100),
)
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
func = tvm.build(sch, args, target)
dev = tvm.cpu()

X_tvm = tvm.nd.array(X_np, device=dev)
W_data_tvm = tvm.nd.array(W_sp_np.data, device=dev)
W_indices_tvm = tvm.nd.array(W_sp_np.indices, device=dev)
W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, device=dev)
Y_tvm = tvm.nd.empty(Y_np.shape, device=dev)

func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm)

# Check results
np.testing.assert_allclose(Y_np, Y_tvm.numpy(), atol=1e-4, rtol=1e-4)
# print("check result done")
# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm).results) * 1000)
)

Here is the error message:

Computational DAG:
placeholder = PLACEHOLDER [987]
placeholder = PLACEHOLDER [777756]
placeholder = PLACEHOLDER [777756]
X = PLACEHOLDER [986, 986]
compute(i, row) += (placeholder[(placeholder[row] + elem_idx)]*X[i, placeholder[(placeholder[row] + elem_idx)]])

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Traceback (most recent call last):
  File "auto_scheduler_sparse_gemm_cpu.py", line 173, in <module>
    task.tune(tune_option)
File "/home/Software/tvm/python/tvm/auto_scheduler/search_task.py", line 498, in tune
    _ffi_api.AutoSchedule(search_policy, tuning_options)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
 
  6: TVMFuncCall
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  4: tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)
  3: tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)
  2: tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)
  1: tvm::auto_scheduler::SketchPolicyNode::SampleInitPopulation(tvm::runtime::Array<tvm::auto_scheduler::State, void> const&)
  0: tvm::support::parallel_for(int, int, std::function<void (int)> const&, int, std::function<std::vector<std::vector<int, std::allocator<int> >, std::allocator<std::vector<int, std::allocator<int> > > > (int, int, int, int)>)
  10: 0xffffffffffffffff
  9: __clone
  8: start_thread
        at /build/glibc-uZu3wS/glibc-2.27/nptl/pthread_create.c:463
  7: 0x00007ff2ed6196de
  6: std::thread::_State_impl<std::thread::_Invoker<std::tuple<std::packaged_task<void (std::vector<int, std::allocator<int> > const&, std::function<void (int)> const&)>, std::vector<int, std::allocator<int> >, std::function<void (int)> > > >::_M_run()
  5: void std::call_once<void (std::__future_base::_State_baseV2::*)(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*), std::__future_base::_State_baseV2*, std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*>(std::once_flag&, void (std::__future_base::_State_baseV2::*&&)(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*), std::__future_base::_State_baseV2*&&, std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*&&, bool*&&)
  4: __pthread_once_slow
        at /build/glibc-uZu3wS/glibc-2.27/nptl/pthread_once.c:116
  3: std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*)
  2: _ZNSt17_Function_handlerIFSt10unique_ptrINSt13__future_base12_Result_baseENS2_8_DeleterEEvENS1_12_Task_setterIS0_INS1_7_ResultIvEES3_EZNS1_11_Task_stateIZN3tvm7support12parallel_forEiiRKSt8functionIFviEEiSD_IFSt6vectorISI_IiSaIiEESaISK_EEiiiiEEEUlRKSK_SH_E_SJ_FvSQ_SH_EE6_M_runESQ_SH_EUlvE_vEEE9_M_in
  1: std::_Function_handler<void (int), tvm::auto_scheduler::SketchPolicyNode::SampleInitPopulation(tvm::runtime::Array<tvm::auto_scheduler::State, void> const&)::{lambda(int)#1}>::_M_invoke(std::_Any_data const&, int&&)
  0: tvm::auto_scheduler::InitFillTileSize::Apply(tvm::auto_scheduler::SketchPolicyNode*, tvm::auto_scheduler::State*, std::mersenne_twister_engine<unsigned long, 32ul, 624ul, 397ul, 31ul, 2567483615ul, 11ul, 4294967295ul, 7ul, 2636928640ul, 15ul, 4022730752ul, 18ul, 1812433253ul>*) const

File "/home/Software/tvm/src/support/parallel_for.cc", line 92
TVMError: Parallel_for error with [14:39:10] /home/Software/tvm/src/auto_scheduler/search_policy/sketch_policy_rules.cc:515:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (ps->extent) is false:
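
Read as plain loops, the `compute(i, row)` line in the Computational DAG output above appears to correspond to the sketch below. The function name `csr_sparse_dense` and the tiny inputs are illustrative assumptions. The point to notice is that the trip count of the innermost loop comes from `indptr`, so it is known only at run time. Whether that data-dependent extent is what trips the `ps->extent` check is a guess, not something the log confirms.

```python
def csr_sparse_dense(X, data, indices, indptr):
    """Compute Y = X @ W.T, where W (N x K) is supplied as CSR arrays."""
    M = len(X)
    N = len(indptr) - 1
    Y = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for row in range(N):
            row_start = indptr[row]
            row_elems = indptr[row + 1] - row_start  # extent depends on the data
            for elem_idx in range(row_elems):
                elem = row_start + elem_idx
                Y[i][row] += data[elem] * X[i][indices[elem]]
    return Y


# Hypothetical inputs: W = [[5, 0, 0], [0, 0, 2], [0, 3, 4]] in CSR form.
data = [5.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 2]
indptr = [0, 1, 2, 4]
X = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
Y = csr_sparse_dense(X, data, indices, indptr)
print(Y)
```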

I also tried to define the search policy as follows, but it does not work either. The error message is the same as above.

search_policy = auto_scheduler.SketchPolicy(
                    task,
                    program_cost_model=auto_scheduler.XGBModel(),)
task.tune(tune_option, search_policy)

Thanks for your time!

Hi, I got the same error message. Did you fix this bug?

Hi, unfortunately, I did not find a way to solve it. Please let me know if you have any ideas for tackling this issue.

I just found what I did wrong. In meet_condition_func, I had not specified the correct tag of the op I want to tune. I didn't see a meet_condition_func in your code, though, so maybe you could check that. Hope this helps :)
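
For context, a rough sketch of what a tag-based `meet_condition_func` might look like, modeled on the pattern in the sparse auto-scheduling tutorial. The tag string `sparse_dense_sp_rhs_csrmm` is an assumption about how TOPI labels the CSR stage and may differ between TVM versions, so it is worth printing the op's tag to confirm it. The helper `tag_matches` is a hypothetical name introduced only to make the check testable without TVM.

```python
# Assumed tag of the CSR sparse-dense stage; verify it against your TVM build.
TARGET_TAGS = {"sparse_dense_sp_rhs_csrmm"}


def tag_matches(op_tag, target_tags=TARGET_TAGS):
    """Return True if a stage's op tag is one the custom rule should handle."""
    return op_tag in target_tags


def meet_condition_func(search_policy, state, stage_id):
    """Condition function for a custom sketch rule (requires TVM to call)."""
    from tvm import auto_scheduler  # imported lazily so tag_matches runs without TVM

    state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    if tag_matches(state.stages[stage_id].op.tag):
        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
    return auto_scheduler.PreloadCustomSketchRule.PASS


print(tag_matches("sparse_dense_sp_rhs_csrmm"))
print(tag_matches("sparse_dense_sp_rhs_bsrmm"))
```

If the tag in `TARGET_TAGS` never matches, the custom rule is never applied and the default sketch rules handle the stage instead, which would be consistent with the mistake described in the reply above.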