Auto-scheduling for lstm operator

Thanks for your reply~ The reason for using a large graph is that the build time is too long when the LSTM computation function is expressed as small subgraphs. Building the LSTM computation declaration as small subgraphs takes more than 3723 s, because the Relay graph for this computation declaration is very large (more than 20000 lines, or even more), and this is too time-consuming… The small-subgraph computation code (which I developed) is here (tvm/relay/frontend/pytorch.py):

def _lstm():
    '''
    Build the converter for an LSTM op, unrolled into elementary `_op` calls.

    Every time step of every layer is emitted as its own group of
    add / split / sigmoid / tanh / multiply calls, so the amount of emitted
    graph grows with (number of steps) x (number of layers).

    return: the `_impl(inputs, input_types)` closure
    '''
    def lstm_cell(unbind_input, input_hidden, cell_param):
        '''
        Emit a single LSTM step.

        unbind_input: input of this step (a 2D-Tensor per the author's notes)
        input_hidden: pair (hx, cx) holding the previous hidden and cell state
        cell_param: object providing linear_ih(...) and linear_hh(...)
        return the pair (hy, cy) produced by this step
        '''
        prev_h, prev_c = input_hidden[0], input_hidden[1]

        # Sum of the input projection and the hidden projection.
        pre_activation = _op.add(cell_param.linear_ih(unbind_input),
                                 cell_param.linear_hh(prev_h))

        # Cut axis 1 into four equal parts, one per gate.
        quarters = _op.split(pre_activation, indices_or_sections=4, axis=1)
        assert(len(quarters) == 4)
        i_t = _op.sigmoid(quarters[0])  # input gate
        f_t = _op.sigmoid(quarters[1])  # forget gate
        g_t = _op.tanh(quarters[2])     # cell candidate
        o_t = _op.sigmoid(quarters[3])  # output gate

        kept = _op.multiply(f_t, prev_c)
        written = _op.multiply(i_t, g_t)
        next_c = _op.add(kept, written)
        next_h = _op.multiply(o_t, _op.tanh(next_c))
        return next_h, next_c


    def full_layer(_input_unbind_list, input_hidden, cell_param):
        '''
        Run one layer over the whole sequence, one lstm_cell call per step.

        _input_unbind_list: indexable sequence of per-step inputs
        input_hidden: pair (hx, cx) used as the state before the first step
        cell_param: parameters shared by every step of this layer
        return (list of hy for each step, state pair after the last step)
        '''
        state = input_hidden
        emitted = []
        for step in range(len(_input_unbind_list)):
            # The state produced by one step is fed into the next one.
            state = lstm_cell(_input_unbind_list[step], state, cell_param)
            emitted.append(state[0])
        return emitted, state


    def apply_layer_stack(_input_unbind_list, hiddens, cell_param_list, num_layers):
        '''
        Chain num_layers layers: the per-step outputs of a layer become the
        per-step inputs of the following layer.

        _input_unbind_list: per-step inputs of the first layer
        hiddens: one initial (hx, cx) pair per layer
        cell_param_list: one parameter object per layer
        num_layers: int, must equal len(hiddens) and len(cell_param_list)
        return (per-step outputs of the last layer,
                list with the final (hy, cy) pair of every layer)
        '''
        assert(len(hiddens) == num_layers)
        assert(len(cell_param_list) == num_layers)
        sequence = _input_unbind_list
        last_states = []
        for layer in range(num_layers):
            sequence, layer_state = full_layer(sequence, hiddens[layer], cell_param_list[layer])
            last_states.append(layer_state)
        return sequence, last_states


    def _lstm_impl(_input, cell_param_list, hx, cx, num_layers, dropout_p, train, bidirectional):
        '''
        Unbind the inputs, run the layer stack and stack the results again.

        _input: input sequence (author's note: 3D Tensor such as [158,1,2048])
        cell_param_list: one parameter object per layer
        hx: initial hidden states (author's note: 3D Tensor such as [2, 1, 1024])
        cx: initial cell states (author's note: 3D Tensor such as [2, 1, 1024])
        num_layers: int
        dropout_p, train, bidirectional: accepted but not read by this function
        return (layer_output, hy_stack, cy_stack), each built with
               _op.stack(..., axis=0)
        '''
        steps = unbind_func(_input)
        h_per_layer = unbind_func(hx)
        c_per_layer = unbind_func(cx)
        assert (len(h_per_layer) == len(c_per_layer))
        assert (len(cell_param_list) == len(c_per_layer))
        assert (len(cell_param_list) == num_layers)

        # Pair up the initial hidden and cell state of each layer.
        initial_states = [(h_per_layer[k], c_per_layer[k])
                          for k in range(len(h_per_layer))]

        top_outputs, last_states = apply_layer_stack(steps, initial_states, cell_param_list, num_layers)
        layer_output = _op.stack(top_outputs, axis=0)
        assert(len(last_states) == num_layers)

        hy_stack = _op.stack([state[0] for state in last_states], axis=0)
        cy_stack = _op.stack([state[1] for state in last_states], axis=0)
        return layer_output, hy_stack, cy_stack


    def _impl(inputs, input_types):
        '''
        Converter entry point: pick the operands out of `inputs` by position
        and delegate to _lstm_impl. `input_types` is not used.
        '''
        sequence = inputs[0]         # input tensor (author's note: 3D, e.g. [316,1,240])
        states = inputs[1]           # two-element list: initial hx and cx
        flat_params = inputs[2]      # tensor list handed to gather_params
        with_bias = inputs[3]        # bool
        layer_count = inputs[4]      # int
        dropout = inputs[5]          # float
        training = inputs[6]         # bool
        two_directions = inputs[7]   # bool
        batch_first = inputs[8]      # bool, read but not used below
        assert len(states) == 2  # "lstm expects two hidden states"
        per_layer_params = gather_params(flat_params, with_bias)
        return _lstm_impl(sequence, per_layer_params, states[0], states[1],
                          layer_count, dropout, training, two_directions)

    return _impl

In order to decrease the size of the Relay graph and the build time, I redefined the LSTM computation declaration as a large graph, like this:

def _lstm_new():
    '''
    Build the converter for an LSTM op that emits one `_op.nn.lstm_layer`
    call per layer instead of unrolling every time step.

    return: the `_impl(inputs, input_types)` closure
    '''
    def _lstm_impl(_input, cell_param_list, hx, cx, num_layers):
        '''
        Feed the input through num_layers chained `_op.nn.lstm_layer` calls.

        _input: input sequence (author's note: 3D Tensor such as [158,1,2048])
        cell_param_list: one object per layer exposing w_ih, w_hh, b_ih, b_hh
        hx: initial hidden states (author's note: 3D Tensor such as [2, 1, 1024])
        cx: initial cell states (author's note: 3D Tensor such as [2, 1, 1024])
        num_layers: int, must match the number of unbound hx / cx entries and
                    len(cell_param_list)
        return: (data, None, None) -- only the output of the last layer is
                produced; the two final-state slots are always None
        '''
        h_per_layer = unbind_func(hx)
        c_per_layer = unbind_func(cx)
        assert (len(h_per_layer) == len(c_per_layer))
        assert (len(h_per_layer) == num_layers)
        assert (len(cell_param_list) == num_layers)

        data = _input
        for layer in range(num_layers):
            layer_h = h_per_layer[layer]
            layer_c = c_per_layer[layer]
            weights = cell_param_list[layer]
            # The output of one layer is the input of the next one.
            data = _op.nn.lstm_layer(data, layer_h, layer_c,
                                     weights.w_ih, weights.w_hh,
                                     weights.b_ih, weights.b_hh, num_layers)
        return data, None, None

    def _impl(inputs, input_types):
        '''
        Converter entry point: pick the operands out of `inputs` by position
        and delegate to _lstm_impl. `input_types` is not used.
        '''
        sequence = inputs[0]         # input tensor (author's note: 3D, e.g. [316,1,240])
        states = inputs[1]           # two-element list: initial hx and cx
        flat_params = inputs[2]      # tensor list handed to gather_params
        with_bias = inputs[3]        # bool
        layer_count = inputs[4]      # int
        dropout = inputs[5]          # float, read but not used below
        training = inputs[6]         # bool, read but not used below
        two_directions = inputs[7]   # bool, read but not used below
        batch_first = inputs[8]      # bool, read but not used below
        assert len(states) == 2  # "lstm expects two hidden states"
        per_layer_params = gather_params(flat_params, with_bias)
        return _lstm_impl(sequence, per_layer_params, states[0], states[1], layer_count)

    return _impl

In this way, the Relay graph is small and the build is three times faster than before. Next, I want to use the auto-scheduler to generate an optimal schedule, in order to complete the import of the RNNT model into TVM. It seems that the large graph is necessary for me to define the computation declaration for the LSTM op. So, could you please give me some guidance on the next step? Thank you very much.