Auto-scheduling for the LSTM operator

Hi, I tried tvm.te.scan but it failed with error message TVMError: Unsupported op scan. (commit 0e0adf51 2021/12/09)

I have also tried relay.testing, but it failed too. (Cannot build the LSTM model in relay.testing.)

What is the appropriate way to apply auto-scheduling for an lstm operator? Any feedback on this would be greatly appreciated.

Here is my code; the LSTM implementation is taken from tvm/apps/topi_recipe/rnn/lstm.py:

import tvm
from tvm import te, auto_scheduler, topi


@auto_scheduler.register_workload
def lstm(num_step = 128, num_input = 256, num_hidden = 512, batch_size=2):
    """Build the LSTM compute graph (adapted from tvm/apps/topi_recipe/rnn/lstm.py).

    Constructs the tensor-expression DAG for a single-layer LSTM whose
    input-to-hidden projection (Xi2h) is assumed to be precomputed, and whose
    recurrence over time steps is expressed with tvm.te.scan.

    Returns a list of all tensors in the graph, ending with the two scan
    outputs (hidden state, cell state).

    NOTE(review): auto_scheduler's ComputeDAG rejects scan ops ("Unsupported
    op scan" in the traceback below), so registering this workload fails.
    NOTE(review): `num_input` is never used in the body — presumably because
    the input projection Xi2h is taken as a precomputed placeholder.
    """

    # Global transition matrix
    # Input hidden channel can be pre-caculated by a gemm
    Xi2h = te.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = te.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    # State placeholders carried through the scan; index t-1 reads the
    # previous time step.
    s_state_h = te.placeholder((num_step, batch_size, num_hidden))
    s_state_c = te.placeholder((num_step, batch_size, num_hidden))
    # Initial hidden/cell state: zeros for the first time step.
    s_init_c = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
    s_init_h = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
    # LSTM transition
    k = te.reduce_axis((0, num_hidden), name="ki2h")
    # Hidden-to-hidden GEMM: contract previous hidden state with Wh2h over
    # the hidden dimension, producing all 4 gate pre-activations at once.
    s_h2h = te.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: te.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h",
    )
    # Gate rules
    # Gate pre-activations = input projection + hidden projection.
    gates = te.compute(Xi2h.shape, lambda *i: Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    # Slice out the four gates along axis 2 and apply their nonlinearities.
    in_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 0, j]), name="in_gate")
    in_transform = te.compute(
        gshape, lambda t, i, j: te.tanh(gates[t, i, 1, j]), name="in_transform"
    )
    forget_gate = te.compute(
        gshape, lambda t, i, j: te.sigmoid(gates[t, i, 2, j]), name="forget_gate"
    )
    out_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 3, j]), name="out_gate")
    # c_t = f_t * c_{t-1} + i_t * g_t
    next_c = te.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
        + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c",
    )
    # h_t = o_t * tanh(c_t)
    next_h = te.compute(
        gshape, lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h"
    )
    # Identity copies fed back as the scan's state-update tensors.
    update_c = te.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = te.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    # te.scan ties (init, update, state_placeholder) triples into the
    # time-step recurrence; this is the op auto_scheduler cannot handle.
    scan_h, scan_c = tvm.te.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan",
    )
    # NOTE(review): in_gate and in_transform are missing from this return
    # list while the other intermediate tensors are included — confirm
    # whether that is intentional.
    return [Xi2h, Wh2h, s_state_h, s_state_c, s_init_c, s_init_h, s_h2h, gates, forget_gate, out_gate, next_c, next_h, update_c, update_h, scan_h, scan_c]


def main():
    """Create the auto-scheduler search task for the LSTM workload and tune it on CUDA."""
    target = tvm.target.Target("cuda")

    # Workload dimensions: sequence length, input width, hidden width, batch.
    seq_len, in_dim, hidden_dim, batch_size = 32, 512, 1024, 1

    log_file = f'lstm_s{seq_len}_d{in_dim}_h{hidden_dim}.json'

    task = tvm.auto_scheduler.SearchTask(
        func=lstm,
        args=(seq_len, in_dim, hidden_dim, batch_size),
        target=target,
    )

    # Short tuning run; measured schedules are appended to the log file.
    options = auto_scheduler.TuningOptions(
        num_measure_trials=10,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    task.tune(options)


# Guard the entry point so importing this module does not kick off tuning.
if __name__ == "__main__":
    main()

And the error message is:

Traceback (most recent call last):
  File "lstm_snippet.py", line 78, in <module>
    main()
  File "lstm_snippet.py", line 67, in main
    target=target)
  File "/tvm/python/tvm/auto_scheduler/search_task.py", line 447, in __init__
    compute_dag = ComputeDAG(workload_key)
  File "/tvm/python/tvm/auto_scheduler/compute_dag.py", line 124, in __init__
    self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
  File "/tvm/python/tvm/_ffi/_ctypes/object.py", line 136, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  4: TVMFuncCall
  3: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::auto_scheduler::ComputeDAG (tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}>(tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  2: tvm::auto_scheduler::ComputeDAG::ComputeDAG(tvm::runtime::Array<tvm::te::Tensor, void>)
  1: tvm::auto_scheduler::AccessAnalyzer::AccessAnalyzer(tvm::runtime::Array<tvm::te::Tensor, void> const&)
  0: tvm::auto_scheduler::TopoSortOps(tvm::runtime::Array<tvm::te::Tensor, void> const&)
  File "/tvm/src/auto_scheduler/compute_dag.cc", line 97
TVMError: Unsupported op scan(lstm_scan, 0x34f9100)