Hi, I tried tvm.te.scan, but it failed with the error message TVMError: Unsupported op scan (commit 0e0adf51, 2021/12/09).
I have also tried relay.testing, but it failed too (see the thread "Cannot build LSTM model in relay.testing").
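For reference, this is roughly what I ran on the Relay side (a minimal sketch; the iterations/num_hidden values are just my test configuration):

from tvm.relay.testing import lstm as relay_lstm

# relay.testing.lstm.get_workload builds an unrolled LSTM module;
# compiling it failed for me as described in the thread above.
mod, params = relay_lstm.get_workload(iterations=32, num_hidden=512)
print(mod)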
What is the appropriate way to apply auto-scheduling to an LSTM operator? Any feedback on this would be greatly appreciated.
Here is my code; the LSTM implementation is adapted from tvm/apps/topi_recipe/rnn/lstm.py:
import tvm
from tvm import te, auto_scheduler


@auto_scheduler.register_workload
def lstm(num_step=128, num_input=256, num_hidden=512, batch_size=2):
    # Global transition matrix.
    # The input-to-hidden projection can be pre-computed by a single GEMM.
    Xi2h = te.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle the hidden-to-hidden transition; this saves space.
    Wh2h = te.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = te.placeholder((num_step, batch_size, num_hidden))
    s_state_c = te.placeholder((num_step, batch_size, num_hidden))
    s_init_c = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
    s_init_h = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
    # LSTM transition: project the previous hidden state for all four gates.
    k = te.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = te.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: te.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h",
    )
    # Gate rules
    gates = te.compute(Xi2h.shape, lambda *i: Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 0, j]), name="in_gate")
    in_transform = te.compute(
        gshape, lambda t, i, j: te.tanh(gates[t, i, 1, j]), name="in_transform"
    )
    forget_gate = te.compute(
        gshape, lambda t, i, j: te.sigmoid(gates[t, i, 2, j]), name="forget_gate"
    )
    out_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 3, j]), name="out_gate")
    next_c = te.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
        + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c",
    )
    next_h = te.compute(
        gshape, lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h"
    )
    update_c = te.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = te.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Tie the recurrence together with te.scan; this is the op the
    # auto-scheduler rejects below.
    scan_h, scan_c = tvm.te.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan",
    )
    return [
        Xi2h, Wh2h, s_state_h, s_state_c, s_init_c, s_init_h, s_h2h, gates,
        forget_gate, out_gate, next_c, next_h, update_c, update_h, scan_h, scan_c,
    ]
def main():
    target = tvm.target.Target("cuda")
    seq_len = 32
    in_dim = 512
    hidden_dim = 1024
    batch_size = 1
    task = tvm.auto_scheduler.SearchTask(
        func=lstm, args=(seq_len, in_dim, hidden_dim, batch_size), target=target
    )
    log_file = f"lstm_s{seq_len}_d{in_dim}_h{hidden_dim}.json"
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=10,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    task.tune(tune_option)


main()
And the error message is:
Traceback (most recent call last):
  File "lstm_snippet.py", line 78, in <module>
    main()
  File "lstm_snippet.py", line 67, in main
    target=target)
  File "/tvm/python/tvm/auto_scheduler/search_task.py", line 447, in __init__
    compute_dag = ComputeDAG(workload_key)
  File "/tvm/python/tvm/auto_scheduler/compute_dag.py", line 124, in __init__
    self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
  File "/tvm/python/tvm/_ffi/_ctypes/object.py", line 136, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
4: TVMFuncCall
3: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::auto_scheduler::ComputeDAG (tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}>(tvm::auto_scheduler::{lambda(tvm::runtime::Optional<tvm::runtime::Array<tvm::te::Tensor, void> >, tvm::runtime::Optional<tvm::te::Schedule>)#4}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
2: tvm::auto_scheduler::ComputeDAG::ComputeDAG(tvm::runtime::Array<tvm::te::Tensor, void>)
1: tvm::auto_scheduler::AccessAnalyzer::AccessAnalyzer(tvm::runtime::Array<tvm::te::Tensor, void> const&)
0: tvm::auto_scheduler::TopoSortOps(tvm::runtime::Array<tvm::te::Tensor, void> const&)
File "/tvm/src/auto_scheduler/compute_dag.cc", line 97
TVMError: Unsupported op scan(lstm_scan, 0x34f9100)
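Is the expected workaround to express each time step as a plain TE workload and drive the recurrence from the host side? As a minimal sketch of what I mean (lstm_step_gemm is my own name; it covers only one step's hidden-to-hidden projection, so the DAG contains no scan op):

import tvm
from tvm import te, auto_scheduler


@auto_scheduler.register_workload
def lstm_step_gemm(batch_size=1, num_hidden=1024):
    # Hidden-to-hidden projection for all four gates of a single step:
    # a plain batched GEMM that the auto-scheduler should be able to analyze.
    h_prev = te.placeholder((batch_size, num_hidden), name="h_prev")
    Wh2h = te.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    k = te.reduce_axis((0, num_hidden), name="k")
    h2h = te.compute(
        (batch_size, 4, num_hidden),
        lambda i, x, j: te.sum(h_prev[i, k] * Wh2h[x, j, k], axis=k),
        name="h2h",
    )
    return [h_prev, Wh2h, h2h]


task = auto_scheduler.SearchTask(
    func=lstm_step_gemm, args=(1, 1024), target=tvm.target.Target("cuda")
)

If that is the intended approach, how should the time loop be handled so the per-step elementwise ops still get fused with the GEMM?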