How to use Relay Control Flow?

Hi, Jiali and I tried to compile the RNN-T PyTorch model with TVM, so we implemented an LSTM converter in /incubator-tvm/python/tvm/relay/frontend/pytorch.py:

def _lstm():
    # _op, _expr, _infer_shape and _activation_map are assumed to be
    # available in the enclosing relay frontend module scope.
    def _lstm_cell(input, hidden, params):
        hx = hidden[0]
        cx = hidden[1]
        _w_ih = params[0]
        _w_hh = params[1]
        _b_ih = params[2]
        _b_hh = params[3]
        i2h = _op.nn.bias_add(_op.nn.dense(input, _w_ih), _b_ih, axis=-1)
        h2h = _op.nn.bias_add(_op.nn.dense(hx, _w_hh), _b_hh, axis=-1)
        gates = i2h + h2h
        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)

        in_gate = _activation_map["sigmoid"](slice_gates[0])  # (1, 1024)
        forget_gate = _activation_map["sigmoid"](slice_gates[1])
        cell_gate = _activation_map["tanh"](slice_gates[2])
        out_gate = _activation_map["sigmoid"](slice_gates[3])
        cy = forget_gate * cx + in_gate * cell_gate  # next_c
        hy = out_gate * _activation_map["tanh"](cy)  # next_h
        return [hy, cy]

    def gather_params(_params, has_biases):
        """Group the flat parameter list into per-layer [w_ih, w_hh, b_ih, b_hh]."""
        res = []
        if has_biases:
            assert len(_params) % 4 == 0, "got an incorrect number of RNN parameters (bias)"
            for i in range(0, len(_params), 4):
                res.append([_params[i], _params[i + 1], _params[i + 2], _params[i + 3]])
        else:
            assert len(_params) % 2 == 0, "got an incorrect number of RNN parameters (no bias)"
            for i in range(0, len(_params), 2):
                # zero bias placeholder; float32 to match the dense output dtype
                zero = _expr.const(0, dtype="float32")
                res.append([_params[i], _params[i + 1], zero, zero])
        return res

    def unsqueeze_hidden(hiddens):
        # prepend a new axis: (batch, hidden) -> (1, batch, hidden)
        return _op.expand_dims(hiddens, axis=0)

    def full_layer(step_inputs, input_hidden, params):
        step_outputs = []
        hidden = input_hidden
        for input in step_inputs:
            hidden = _lstm_cell(input, hidden, params)
            step_outputs.append(hidden[0])

        hidden[0] = unsqueeze_hidden(hidden[0])
        hidden[1] = unsqueeze_hidden(hidden[1])
        return step_outputs, hidden

    def apply_layer_stack(input, hiddens, weights, num_layers):
        layer_input = input
        final_hiddens = []
        for i in range(num_layers):
            layer_output_outputs, layer_output_hidden = full_layer(layer_input, hiddens[i], weights[i])
            final_hiddens.append(layer_output_hidden)
            layer_input = layer_output_outputs

        layer_out = []
        for li in layer_input:
            layer_out.append(unsqueeze_hidden(li))
        return layer_out, final_hiddens  # final_hiddens: list

    def _lstm_impl(input, params, hx, cx, num_layers):
        hx_shape = _infer_shape(hx)
        cx_shape = _infer_shape(cx)
        layer_hx = _op.split(hx, hx_shape[0], axis=0)
        layer_cx = _op.split(cx, cx_shape[0], axis=0)

        total_layers = len(layer_hx)
        hiddens = []
        for i in range(total_layers):
            hiddens.append([_op.squeeze(layer_hx[i], axis=[0]), _op.squeeze(layer_cx[i], axis=[0])])

        res_output, res_hidden = apply_layer_stack(input, hiddens, params, num_layers)
        hy = []
        cy = []
        for hidden in res_hidden:
            hy.append(hidden[0])
            cy.append(hidden[1])

        hy_res = _op.concatenate(hy, 0)
        cy_res = _op.concatenate(cy, 0)
        res_output_res = _op.concatenate(res_output, 0)
        return res_output_res, hy_res, cy_res

    def _impl(inputs, input_types):
        _input = inputs[0]  # Tensor
        shape = _infer_shape(_input)
        temp_input = _op.split(_input, indices_or_sections=shape[0], axis=0)
        # print("factor: ", shape[0], "axis: ", 0)
        input_list = []
        for item in temp_input:
            input_list.append(_op.squeeze(item, axis=[0]))

        hx = inputs[1]  # TensorList
        _params = inputs[2]  # TensorList
        has_biases = inputs[3]  # bool
        num_layers = inputs[4]  # int64_t
        dropout_p = inputs[5]  # double (currently ignored)
        train = inputs[6]  # bool (currently ignored)
        bidirectional = inputs[7]  # bool (currently ignored; only unidirectional is handled)
        batch_first = inputs[8]  # bool (currently ignored)
        assert len(hx) == 2, "lstm expects two hidden states"

        params = gather_params(_params, has_biases)
        results = _lstm_impl(input_list, params, hx[0], hx[1], num_layers)
        return results

    return _impl

But we found that it generates a very long Relay expression and runs slowly. merrymercy suggested that we use control flow instead of unrolling the LSTM cell. Could you give us some advice, or some examples? Thanks very much. @junrushao @MarisaKirisame @jroesch You can refer to this post: Auto-scheduling for lstm operator

Can you send a PR to add your LSTM converter implementation? This is a requested feature (see https://github.com/apache/incubator-tvm/issues/6474).

Unrolling is the standard way to implement LSTM op conversion; both the MXNet and ONNX frontends do it. I don’t recommend pursuing the control-flow approach when the number of layers is known at compile time.

There is an example of generating a while loop in the PyTorch frontend, here: https://github.com/apache/incubator-tvm/blob/main/python/tvm/relay/frontend/pytorch.py#L3222-L3223. You can use it to turn a static LSTM op into a dynamic LSTM, but the usage is complicated, so I don’t recommend it.
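
For reference, here is a minimal sketch of the helper that code builds on, tvm.relay.loops.while_loop, assuming a recent TVM. The counter-and-accumulator loop below is only a stand-in for an LSTM time-step loop, with the accumulator playing the role of the hidden state:

import tvm
from tvm import relay
from tvm.relay.loops import while_loop

# loop variables: a step counter and an accumulator (stand-ins for the
# time index and the hidden state of an LSTM)
i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def cond(i, acc):
    # keep iterating while i < 10
    return relay.less(i, relay.const(10, "int32"))

def body(i, acc):
    # return the updated loop variables as a tuple
    return i + relay.const(1, "int32"), acc + i

loop = while_loop(cond, [i, acc], body)
# calling the loop yields a tuple of the final loop variables
out = loop(relay.const(0, "int32"), relay.const(0, "int32"))
func = relay.Function([], relay.TupleGetItem(out, 1))
mod = tvm.IRModule.from_expr(func)
print(relay.create_executor("vm", mod=mod).evaluate()())  # sum of 0..9 = 45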

Thanks! The current LSTM implementation is not friendly since it generates a very long Relay expression and, most importantly, it runs very slowly. I think implementing LSTM in TOPI would be a better way.

@jwfromm and I did ONNX LSTM a few months ago and decided to unroll because the rest of the ONNX importer only supported static shapes. We’ve recently fixed that, and Josh recently implemented ONNX Loop using Relay’s recursive while function. We could probably use that experience to go back and rethink LSTM.

Relay has support for recursion, pattern matching, and conditionals; with them you can define an LSTM.
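
As a hedged illustration in plain Relay (not existing frontend code), here is the recursion + conditional part: a Let-bound recursive function with an If, which is exactly what while_loop desugars to. An LSTM time loop would carry the hidden and cell states as extra arguments, and pattern matching enters once the time steps are stored in the Prelude’s list ADT:

import tvm
from tvm import relay

# let loop = fn(i) { if i < 5 then loop(i + 1) else i }; loop(0)
i = relay.var("i", shape=(), dtype="int32")
loop = relay.Var("loop")
body = relay.If(
    relay.less(i, relay.const(5, "int32")),
    loop(i + relay.const(1, "int32")),  # recursive call
    i,                                  # base case: return the counter
)
main = relay.Let(loop, relay.Function([i], body), loop(relay.const(0, "int32")))
mod = tvm.IRModule.from_expr(relay.Function([], main))
print(relay.create_executor("debug", mod=mod).evaluate()())  # 5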