Auto-scheduling for LSTM operator

Hi experts, @Lianminzheng @jcf94 I tried to define a new operator for an LSTM network. The compute declaration for the LSTM op has been tested and is correct. Now I want to use auto-scheduling to automatically generate a large search space and find a good schedule in that space, but the schedule cannot be generated successfully. My code is here:

from tvm import topi

def unbind_func(data):
    """Unbind a 3-D tensor along axis 0 into a list of 2-D tensors."""
    input_list = topi.split(data, indices_or_sections=data.shape[0].value, axis=0)
    input_sq_list = []
    for item in input_list:
        input_sq = topi.squeeze(item, axis=0)
        input_sq_list.append(input_sq)
    return input_sq_list

def lstm_layer(data, hx, cx, w_ih, w_hh, b_ih, b_hh, out_dtype=None):
    """The default implementation of lstm_layer in topi.

    Parameters
    ----------
    data : tvm.te.Tensor
        3-D with shape [x, y, z]

    hx : tvm.te.Tensor
        2-D with shape [a, b]

    cx : tvm.te.Tensor
        2-D with shape [a, b]

    w_ih : tvm.te.Tensor
        2-D with shape [4 * b, z]

    w_hh : tvm.te.Tensor
        2-D with shape [4 * b, b]

    b_ih : tvm.te.Tensor
        1-D with shape [4 * b]

    b_hh : tvm.te.Tensor
        1-D with shape [4 * b]

    out_dtype : str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.te.Tensor
        3-D with shape [x, a, b] (the stacked per-step hidden states)
    """
    assert len(data.shape) == 3 and len(hx.shape) == 2 and len(cx.shape) == 2 and len(w_ih.shape) == 2 \
           and len(w_hh.shape) == 2 and len(b_ih.shape) == 1 and len(b_hh.shape) == 1, \
           "data must be 3-D; hx, cx, w_ih, w_hh must be 2-D; b_ih, b_hh must be 1-D"

    if out_dtype is None:
        out_dtype = data.dtype

    # unbind input data
    input_list = unbind_func(data)
    step_outputs = []
    for step_input in input_list:
        # step_input is a 2-D tensor: one time step of the sequence
        linear_ih = topi.nn.dense(step_input, w_ih, b_ih)
        linear_hh = topi.nn.dense(hx, w_hh, b_hh)
        gates = topi.add(linear_ih, linear_hh)
        chunked_gates = topi.split(gates, indices_or_sections=4, axis=1)
        assert (len(chunked_gates) == 4)
        in_gate = topi.sigmoid(chunked_gates[0])
        forget_gate = topi.sigmoid(chunked_gates[1])
        cell_gate = topi.tanh(chunked_gates[2])
        out_gate = topi.sigmoid(chunked_gates[3])
        cy = topi.add(topi.multiply(forget_gate, cx), topi.multiply(in_gate, cell_gate))
        hy = topi.multiply(out_gate, topi.tanh(cy))

        step_outputs.append(hy)
        hx = hy
        cx = cy
    output = topi.stack(step_outputs, axis=0)
    return output

import tvm
from tvm import te, auto_scheduler, topi

@auto_scheduler.register_workload
def lstm_layers(hx, cx, w_ih, w_hh, b_ih, b_hh):
    data = te.placeholder((2, 1, 240), name="data")
    out = topi.nn.lstm_layer(data, hx, cx, w_ih, w_hh, b_ih, b_hh, out_dtype="float32")

    return [data, hx, cx, w_ih, w_hh, b_ih, b_hh, out]

target = tvm.target.Target("cuda")

# the layer in lstm
hx = te.placeholder((1, 1024), name='hx')
cx = te.placeholder((1, 1024), name='cx')
w_ih = te.placeholder((4096, 240), name='w_ih')
w_hh = te.placeholder((4096, 1024), name='w_hh')
b_ih = te.placeholder((4096,), name='b_ih')
b_hh = te.placeholder((4096,), name='b_hh')
task = auto_scheduler.create_task(lstm_layers, (hx, cx, w_ih, w_hh, b_ih, b_hh), target)

# Inspect the computational graph
print(task.compute_dag)

measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile("lstm_layers.json")],
)

sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)

print(tvm.lower(sch, list(args), simple_mode=True))

Only when data = te.placeholder((1, 1, 240), name="data") can the schedule be generated successfully. When data = te.placeholder((?, 1, 240), name="data") (with ? > 1), the compute DAG can be obtained and "Get devices for measurement successfully!" is printed, but the schedule cannot be generated. The detailed error follows:

Traceback (most recent call last):
  File "/root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/tutorials/auto_scheduler/tune_lstm_layers.py", line 109, in <module>
    sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
  File "/root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/python/tvm/auto_scheduler/auto_schedule.py", line 213, in auto_schedule
    sch, tensors = _ffi_api.AutoSchedule(search_policy, tuning_options)
  File "/root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (7) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(TVMFuncCall+0x61) [0x7fa9afe26ec1]
  [bt] (6) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(+0xacaacd) [0x7fa9af1cdacd]
  [bt] (5) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)+0x116) [0x7fa9af1cd1b6]
  [bt] (4) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)+0xa82) [0x7fa9af262f52]
  [bt] (3) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)+0x1c3) [0x7fa9af261f83]
  [bt] (2) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::SampleInitPopulation(tvm::runtime::Array<tvm::auto_scheduler::State, void> const&, int)+0x21e) [0x7fa9af25d39e]
  [bt] (1) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::support::parallel_for(int, int, std::function<void (int)> const&, int, std::function<std::vector<std::vector<int, std::allocator<int> >, std::allocator<std::vector<int, std::allocator<int> > > > (int, int, int, int)>)+0x1273) [0x7fa9af7fb413]
  [bt] (0) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x5f) [0x7fa9af1d171f]
  [bt] (8) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xbd6df) [0x7fa9ab9976df]
  [bt] (7) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(std::thread::_State_impl<std::_Bind_simple<std::packaged_task<void (std::vector<int, std::allocator<int> > const&, std::function<void (int)> const&)> (std::vector<int, std::allocator<int> >, std::function<void (int)>)> >::_M_run()+0xd3) [0x7fa9af7fbb13]
  [bt] (6) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(void std::call_once<void (std::__future_base::_State_baseV2::*)(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*), std::__future_base::_State_baseV2*, std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*>(std::once_flag&, void (std::__future_base::_State_baseV2::*&&)(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*), std::__future_base::_State_baseV2*&&, std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*&&, bool*&&)+0x71) [0x7fa9af7fba01]
  [bt] (5) /lib/x86_64-linux-gnu/libpthread.so.0(+0xf827) [0x7fa9d5c09827]
  [bt] (4) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*)+0x29) [0x7fa9af7fb8e9]
  [bt] (3) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(+0x10f6072) [0x7fa9af7f9072]
  [bt] (2) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(+0xb5a0f2) [0x7fa9af25d0f2]
  [bt] (1) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(tvm::auto_scheduler::InitThreadBind::Apply(tvm::auto_scheduler::SketchPolicyNode*, tvm::auto_scheduler::State*, std::mersenne_twister_engine<unsigned long, 32ul, 624ul, 397ul, 31ul, 2567483615ul, 11ul, 4294967295ul, 7ul, 2636928640ul, 15ul, 4022730752ul, 18ul, 1812433253ul>*) const+0x2f9) [0x7fa9af2743e9]
  [bt] (0) /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/build/libtvm.so(+0xb671ff) [0x7fa9af26a1ff]
  File "/root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/src/support/parallel_for.cc", line 92
TVMError: Parallel_for error with [21:43:57] /root/test6_root/jialipang/TVM-Git-1009/incubator-tvm-master/src/auto_scheduler/search_policy/sketch_policy_rules.cc:710: Check failed: HasCrossThreadReduction(*state, stage_id): 

Could you please help me understand and fix the above issue?

Thanks

Thanks! We have not tried such ops before; this case seems interesting.

I’ll try your code and figure out where the possible bug occurs.


Thanks for reporting. This PR (https://github.com/apache/incubator-tvm/pull/6683) fixed this bug.

However, for large graphs like this, we typically don't treat them as a single search task: treating the whole graph as one task makes the search very inefficient.

For large graphs, we use Relay to partition them into several small subgraphs and treat each small subgraph as a search task. We then use the auto-scheduler to optimize every subgraph and a task scheduler to schedule the search tasks.

The upstreaming of the Relay integration and task scheduler (PR) is still a work in progress. Hopefully, we can finish the Relay integration with tutorials in one or two weeks.
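
For reference, once that lands, the flow should look roughly like this (a minimal sketch following the eventual tvm.auto_scheduler API; the one-layer dense module is only a stand-in so the snippet is self-contained):

import tvm
from tvm import relay, auto_scheduler

# Stand-in Relay module; substitute the module imported from your model.
data = relay.var("data", shape=(1, 240))
weight = relay.var("weight", shape=(1024, 240))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))
params = {}
target = tvm.target.Target("cuda")

# Partition the model into small subgraphs, one search task per subgraph.
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

# The task scheduler allocates measurement trials across all tasks.
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(auto_scheduler.TuningOptions(
    num_measure_trials=200 * len(tasks),
    measure_callbacks=[auto_scheduler.RecordToFile("lstm_subgraphs.json")],
))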


Thanks for your reply~ The reason for using a large graph is that the build time of the small-subgraph LSTM compute function is too long: building the LSTM compute declaration as small subgraphs takes more than 3723 s, because the Relay graph for this declaration is huge (more than 20000 lines). This is too time-consuming… The small-subgraph code (which I developed, in tvm/relay/frontend/pytorch.py) is here:

def _lstm():
    def lstm_cell(unbind_input, input_hidden, cell_param):
        '''
        unbind_input: 2D-Tensor
        input_hidden: tuple(2D-Tensor, 2D-Tensor)
        cell_param: A CellParams object
        return a tuple (2D-Tensor, 2D-Tensor)
        '''
        hx = input_hidden[0] # hx is a 2D tensor
        cx = input_hidden[1] # cx is a 2D tensor

        linear_ih = cell_param.linear_ih(unbind_input)
        linear_hh = cell_param.linear_hh(hx)
        gates = _op.add(linear_ih, linear_hh)
        chunked_gates = _op.split(gates, indices_or_sections=4, axis=1)
        assert(len(chunked_gates) == 4)
        in_gate = _op.sigmoid(chunked_gates[0])
        forget_gate = _op.sigmoid(chunked_gates[1])
        cell_gate = _op.tanh(chunked_gates[2])
        out_gate = _op.sigmoid(chunked_gates[3])
        cy = _op.add(_op.multiply(forget_gate, cx), _op.multiply(in_gate, cell_gate))
        hy = _op.multiply(out_gate, _op.tanh(cy))
        return hy, cy


    def full_layer(_input_unbind_list, input_hidden, cell_param):
        '''
        _input_unbind_list: A list of Tensor [(2D-Tensor), (2D-Tensor), ... , (2D-Tensor)]
        input_hidden: tuple(2D-Tensor, 2D-Tensor)
        cell_param: A CellParams object
        return step_outputs, hidden
        '''
        step_outputs = [] # step_outputs is a list of 2D-tensor [2D-tensor, 2D-tensor]
        hidden = input_hidden
        for i in range(len(_input_unbind_list)):
            hy, cy = lstm_cell(_input_unbind_list[i], hidden, cell_param)
            hidden = (hy, cy)
            step_outputs.append(hy)
        return step_outputs, hidden


    def apply_layer_stack(_input_unbind_list, hiddens, cell_param_list, num_layers):
        '''
        _input_unbind_list: A list of Tensor [[1,240], [1,240], ... , [1, 240]]
        hiddens is a list[tuple(2D-tensor, 2D-tensor), tuple(2D-tensor, 2D-tensor)]
        cell_param_list: a list of CellParams , its length is equal to num_layer
        num_layers: int
        return: layer_input_list is a list of 2D-tensors; final_hiddens is a list whose elements are (2D-tensor, 2D-tensor) tuples
        '''
        assert(len(hiddens) == num_layers)
        assert(len(cell_param_list) == num_layers)
        layer_input_list = _input_unbind_list
        final_hiddens = []
        for i in range(num_layers):
            step_output_tensor_list, hidden = full_layer(layer_input_list, hiddens[i], cell_param_list[i])
            final_hiddens.append(hidden)
            layer_input_list = step_output_tensor_list
        return layer_input_list, final_hiddens


    def _lstm_impl(_input, cell_param_list, hx, cx, num_layers, dropout_p, train, bidirectional):
        '''
        _input: 3D Tensor [158,1,2048]
        cell_param_list: a list of CellParams , its length is equal to num_layer
        hx: a Tensor is a 3D-Tensor [2, 1, 1024]
        cx: a Tensor is a 3D-Tensor [2, 1, 1024]
        num_layers: int
        '''
        _input_unbind_list = unbind_func(_input)
        layer_hx = unbind_func(hx) # layer_hx is a list which includes a 2D tensor
        layer_cx = unbind_func(cx) # layer_cx is a list which includes a 2D tensor
        assert (len(layer_hx) == len(layer_cx))
        assert (len(cell_param_list) == len(layer_cx))
        assert (len(cell_param_list) == num_layers)
        total_layers = len(layer_hx)

        # hiddens is a list[(2D-tensor, 2D-tensor), (2D-tensor, 2D-tensor)]
        hiddens = []
        for i in range(total_layers):
            hiddens.append((layer_hx[i], layer_cx[i]))
        layer_output_list, final_hiddens = apply_layer_stack(_input_unbind_list, hiddens, cell_param_list, num_layers)
        layer_output = _op.stack(layer_output_list, axis=0)
        assert(len(final_hiddens) == num_layers)
        hy = []
        cy = []
        for i in range(len(final_hiddens)):
            hy.append(final_hiddens[i][0])
            cy.append(final_hiddens[i][1])
        hy_stack = _op.stack(hy, axis=0)
        cy_stack = _op.stack(cy, axis=0)
        return layer_output, hy_stack, cy_stack


    def _impl(inputs, input_types):
        _input = inputs[0]  # Tensor  3D-Tensor [316,1,240]
        hx = inputs[1]  # TensorList [(2,1,1024), (2,1,1024)] each tensor is a 3D-Tensor [2, 1, 1024]
        _params = inputs[2]  # TensorList
        has_bias = inputs[3]  # bool
        num_layers = inputs[4]  # int64_t
        dropout_p = inputs[5]  # double
        train = inputs[6]  # bool
        bidirectional = inputs[7]  # bool
        batch_first = inputs[8]  # bool
        assert len(hx) == 2  # "lstm expects two hidden states"
        cell_param_list = gather_params(_params, has_bias)
        results = _lstm_impl(_input, cell_param_list, hx[0], hx[1], num_layers, dropout_p, train, bidirectional)
        return results

    return _impl

In order to decrease the Relay graph size and the build time, I redefined the LSTM compute declaration as a large graph, like this:

def _lstm_new():
    def _lstm_impl(_input, cell_param_list, hx, cx, num_layers):
        '''
        _input: 3D Tensor [158,1,2048]
        cell_param_list: a list of CellParams , its length is equal to num_layer
        hx: a Tensor is a 3D-Tensor [2, 1, 1024]
        cx: a Tensor is a 3D-Tensor [2, 1, 1024]
        num_layers: int
        return:
        data: 3D Tensor
        final_hidden: a list of 2D-Tensor [hy, cy]
        '''

        layer_hx = unbind_func(hx)  # layer_hx is a list which includes a 2D tensor
        layer_cx = unbind_func(cx)  # layer_cx is a list which includes a 2D tensor
        assert (len(layer_hx) == len(layer_cx))
        assert (len(layer_hx) == num_layers)
        assert (len(cell_param_list) == num_layers)

        data = _input
        final_hiddens = []
        for i in range(num_layers):
            out_data = _op.nn.lstm_layer(data, layer_hx[i], layer_cx[i],
                                         cell_param_list[i].w_ih, cell_param_list[i].w_hh,
                                         cell_param_list[i].b_ih, cell_param_list[i].b_hh, num_layers)
            data = out_data
        return data, None, None

    def _impl(inputs, input_types):
        _input = inputs[0]  # Tensor  3D-Tensor [316,1,240]
        hx = inputs[1]  # TensorList [(2,1,1024), (2,1,1024)] each tensor is a 3D-Tensor [2, 1, 1024]
        _params = inputs[2]  # TensorList
        has_bias = inputs[3]  # bool
        num_layers = inputs[4]  # int64_t
        dropout_p = inputs[5]  # double
        train = inputs[6]  # bool
        bidirectional = inputs[7]  # bool
        batch_first = inputs[8]  # bool
        assert len(hx) == 2  # "lstm expects two hidden states"
        cell_param_list = gather_params(_params, has_bias)
        results = _lstm_impl(_input, cell_param_list, hx[0], hx[1], num_layers)
        return results

    return _impl

In this way, the Relay graph is small and the build is about three times faster than before. Next, I want to use the auto-scheduler to generate an optimal schedule to complete the import of the RNNT model into TVM. It seems that the large graph is necessary for my LSTM compute declaration. So, could you please give me some guidance on the next step? Thank you very much

Thanks for your fix. I ran the fixed code and the above error is gone~ For the input shape (2, 1, 240), the schedule can be generated. But for the input shape (3, 1, 240), there is another error: the process finishes with exit code 136 (interrupted by signal 8: SIGFPE).

-------------------------  [ Search ]
------------------------------------------------------------
Generate Sketches		#s: 1
Sample Initial Population	#s: 0	fail_ct: 2048	Time elapsed: 17.35

Process finished with exit code 136 (interrupted by signal 8: SIGFPE)

I don’t know how to make the auto-scheduler work with the real input shape (300, 1, 240)… Looking forward to your advice~

@Jiali

For the slow compilation of the LSTM: I imagine that you are unrolling the LSTM cells, so the Relay program is very long. Since Relay supports control flow, you do not need to unroll all of them; see the sketch below. You can open another post to ask how to do this. (cc @junrushao @MarisaKirisame @jroesch)
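
For example, a minimal sketch (with a simplified stand-in cell, not the full LSTM gates) of expressing the time loop with tvm.relay.loops.while_loop instead of unrolling it:

import tvm
from tvm import relay
from tvm.relay.loops import while_loop

seq_len, batch, hidden = 300, 1, 1024

x = relay.var("x", shape=(seq_len, batch, hidden), dtype="float32")
h0 = relay.var("h0", shape=(batch, hidden), dtype="float32")

# Loop variables: the step counter and the hidden state.
t = relay.var("t", shape=(), dtype="int32")
h = relay.var("h", shape=(batch, hidden), dtype="float32")

def _cond(t, h):
    return relay.min(relay.less(t, relay.const(seq_len, "int32")))

def _body(t, h):
    x_t = relay.take(x, t, axis=0)         # slice out step t: (batch, hidden)
    new_h = relay.tanh(relay.add(x_t, h))  # stand-in for the LSTM cell body
    return t + relay.const(1, "int32"), new_h

loop = while_loop(_cond, [t, h], _body)
result = loop(relay.const(0, "int32"), h0)
func = relay.Function([x, h0], relay.TupleGetItem(result, 1))
mod = tvm.IRModule.from_expr(func)
print(mod)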

For the auto-scheduler error you get: we are aware of it. @comaniac is working on a fix for this problem (i.e., the number of valid states in the initial population is zero).

I do not recommend using the auto-scheduler in this way. You should build a Relay program with control flow, and then use the Relay integration of the auto-scheduler (which will be upstreamed later).

@merrymercy Alternatively, if we are applying the LSTM to a fixed-length sequence, we can use tvm.te.scan. This requires that we handle the IterVarType correctly.
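
For reference, here is a minimal sketch of what te.scan looks like, on a cumulative-sum recurrence following the pattern of the TVM scan tutorial (an LSTM would replace the update rule with the gate computations):

import tvm
from tvm import te

m = te.var("m")  # sequence length
n = te.var("n")  # feature width
X = te.placeholder((m, n), name="X")
# The state placeholder carries the value of the previous step.
s_state = te.placeholder((m, n), name="s_state")
s_init = te.compute((1, n), lambda _, i: X[0, i], name="s_init")
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i], name="s_update")
s_scan = te.scan(s_init, s_update, s_state, inputs=[X], name="scan")

s = te.create_schedule(s_scan.op)
print(tvm.lower(s, [X, s_scan], simple_mode=True))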

@Jiali this could be because the ratio of valid schedules in the entire search space is relatively low, so population sampling cannot find any valid schedule. If the evolutionary search starts from a set of invalid schedules, it is very likely to get trapped among invalid schedules and need a long time to find the first valid one.

I have a local branch that solves this problem by making sure the evolutionary search starts with a set of valid schedules. However, this change increases the sampling time from seconds to minutes, so I need to discuss it with other folks before committing.


Hi, I tried tvm.te.scan but it failed with the error message TVMError: Unsupported op scan (commit 0e0adf51, 2021/12/09).

I have also tried relay.testing, but it failed too (Cannot build LSTM model in relay.testing).

What is the appropriate way to apply auto-scheduling to an LSTM operator? Any feedback would be greatly appreciated.

Here is my code; the LSTM implementation is from tvm/apps/topi_recipe/rnn/lstm.py:

import tvm
from tvm import te, auto_scheduler, topi


@auto_scheduler.register_workload
def lstm(num_step=128, num_input=256, num_hidden=512, batch_size=2):

    # Global transition matrix
    # The input-to-hidden contribution can be pre-calculated by a GEMM
    Xi2h = te.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle the hidden-to-hidden transition, which saves space.
    Wh2h = te.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = te.placeholder((num_step, batch_size, num_hidden), name="s_state_h")
    s_state_c = te.placeholder((num_step, batch_size, num_hidden), name="s_state_c")
    s_init_c = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
    s_init_h = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
    # LSTM transition
    k = te.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = te.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: te.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h",
    )
    # Gate rules
    gates = te.compute(Xi2h.shape, lambda *i: Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 0, j]), name="in_gate")
    in_transform = te.compute(
        gshape, lambda t, i, j: te.tanh(gates[t, i, 1, j]), name="in_transform"
    )
    forget_gate = te.compute(
        gshape, lambda t, i, j: te.sigmoid(gates[t, i, 2, j]), name="forget_gate"
    )
    out_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 3, j]), name="out_gate")
    next_c = te.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
        + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c",
    )
    next_h = te.compute(
        gshape, lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h"
    )
    update_c = te.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = te.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Wrap the recurrence in a scan
    scan_h, scan_c = tvm.te.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan",
    )
    return [Xi2h, Wh2h, s_state_h, s_state_c, s_init_c, s_init_h, s_h2h, gates, forget_gate, out_gate, next_c, next_h, update_c, update_h, scan_h, scan_c]


def main():
    target = tvm.target.Target("cuda")

    seq_len = 32
    in_dim = 512
    hidden_dim = 1024
    batch_size = 1
    task = tvm.auto_scheduler.SearchTask(func=lstm,
                                         args=(seq_len, in_dim, hidden_dim, batch_size),
                                         target=target)
    log_file = f'lstm_s{seq_len}_d{in_dim}_h{hidden_dim}.json'

    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=10,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    task.tune(tune_option)


main()

And the error message is:

Traceback (most recent call last):
  File "lstm_snippet.py", line 78, in <module>
    main()
  File "lstm_snippet.py", line 67, in main
    target=target)
  File "/tvm/python/tvm/auto_scheduler/search_task.py", line 447, in __init__
    compute_dag = ComputeDAG(workload_key)
  File "/tvm/python/tvm/auto_scheduler/compute_dag.py", line 124, in __init__
    self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
  File "/tvm/python/tvm/_ffi/_ctypes/object.py", line 136, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  4: TVMFuncCall
  3: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::auto_scheduler::ComputeDAG (tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}>(tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  2: tvm::auto_scheduler::ComputeDAG::ComputeDAG(tvm::runtime::Array<tvm::te::Tensor, void>)
  1: tvm::auto_scheduler::AccessAnalyzer::AccessAnalyzer(tvm::runtime::Array<tvm::te::Tensor, void> const&)
  0: tvm::auto_scheduler::TopoSortOps(tvm::runtime::Array<tvm::te::Tensor, void> const&)
  File "/tvm/src/auto_scheduler/compute_dag.cc", line 97
TVMError: Unsupported op scan(lstm_scan, 0x34f9100)
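
The traceback shows that ComputeDAG's TopoSortOps rejects the scan stage, so a workload handed to SearchTask must contain only regular te.compute ops. A minimal sketch, assuming you tune just the per-step hidden-to-hidden GEMM (the compute-heavy part of the cell) as its own task:

import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def lstm_h2h(batch_size, num_hidden):
    # Only the hidden-to-hidden GEMM of one step: no scan stage in the DAG.
    s_h = te.placeholder((batch_size, num_hidden), name="s_h")
    Wh2h = te.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    k = te.reduce_axis((0, num_hidden), name="k")
    h2h = te.compute(
        (batch_size, 4, num_hidden),
        lambda i, x, j: te.sum(s_h[i, k] * Wh2h[x, j, k], axis=k),
        name="h2h",
    )
    return [s_h, Wh2h, h2h]

task = auto_scheduler.SearchTask(
    func=lstm_h2h, args=(1, 1024), target=tvm.target.Target("cuda")
)
print(task.compute_dag)  # succeeds, unlike the scan-containing workload
task.tune(auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile("lstm_h2h.json")],
))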