Convert PyTorch LSTM model to TVM via ONNX

Hi, I have been struggling for several hours with the following issue:

I have an LSTM model in PyTorch that I want to convert to TVM. Below is sample code for the model; it contains a single `nn.LSTM` module (with three stacked layers). From what I have found so far, TVM does not yet support LSTM operators when converting directly from PyTorch. Therefore, I tried converting the model to ONNX first and then to TVM, but the conversion does not work correctly.

Here is the code.

    import torch
    import torch.nn as nn
    import numpy as np
    import onnxruntime as ort
    import onnx
    import tvm
    import tvm.relay as relay
    from tvm.contrib import graph_runtime
    from IPython import embed

    # Dimensions shared by the PyTorch, ONNX Runtime and TVM runs below:
    #   num_layers      -> number of stacked LSTM layers (leading dim of h0/c0)
    #   num_hyperparams -> LSTM input_size (features per time step)
    #   batch           -> batch dimension of the input and of h0/c0
    #   hidden_size     -> LSTM hidden_size (last dim of h0/c0)
    num_layers, num_hyperparams, batch, hidden_size = 3, 4, 1, 20

    def to_numpy(tensor):
        """Return *tensor* as a NumPy array, detached from autograd.

        ``detach()`` drops the autograd history so ``.numpy()`` is allowed on
        tensors that require grad. ``cpu()`` is a no-op for CPU tensors (the
        result then shares memory with *tensor*), and copies CUDA tensors to
        the host, where ``.numpy()`` would otherwise raise.
        """
        return tensor.detach().cpu().numpy()

    class My_LSTM(nn.Module):
        """Wrap ``nn.LSTM`` so the state is passed as separate tensors.

        ``forward`` takes ``hn``/``cn`` as individual arguments and returns them
        as individual outputs instead of an ``(h, c)`` tuple. This lines up with
        the three ``input_names`` / ``output_names`` used at ONNX export time.
        """

        def __init__(self, num_hyperparams, hidden_size, num_layers):
            """Build a unidirectional, sequence-first (``batch_first=False``) LSTM.

            Args:
                num_hyperparams: LSTM ``input_size`` (features per time step).
                hidden_size: LSTM ``hidden_size``.
                num_layers: number of stacked LSTM layers.
            """
            super().__init__()
            self.lstm = nn.LSTM(input_size=num_hyperparams,
                                hidden_size=hidden_size,
                                num_layers=num_layers)

        def forward(self, input, hn, cn):
            """Run the LSTM from the given state.

            Args:
                input: ``(seq_len, batch, num_hyperparams)`` tensor.
                hn: initial hidden state, ``(num_layers, batch, hidden_size)``.
                cn: initial cell state, ``(num_layers, batch, hidden_size)``.

            Returns:
                ``(output, hn, cn)``. ``output`` holds the top layer's hidden
                state at every time step. ``hn``/``cn`` are the final states,
                which can be passed back in to continue the sequence.
            """
            output, (hn, cn) = self.lstm(input, (hn, cn))
            return output, hn, cn

    model = My_LSTM(num_hyperparams, hidden_size, num_layers)
    # The LSTM has no dropout, so this does not change numerics; it just makes
    # the inference-mode intent explicit for both the reference run and export.
    model.eval()

    # A single time step (seq_len=1) of random input is fed repeatedly while
    # the hidden/cell state is carried over between calls. Note: `input`
    # shadows the builtin, but the name is kept because later code uses it.
    input = torch.randn(1, batch, num_hyperparams)
    h0 = torch.zeros(num_layers, batch, hidden_size)
    c0 = torch.zeros(num_layers, batch, hidden_size)

    loop_num = 5
    hn = h0
    cn = c0
    output_torch = []

    # Reference run #1: PyTorch. Without no_grad(), every step would build a
    # new autograd graph chained onto the previous step's hn/cn, so memory
    # would grow with loop_num for no benefit.
    with torch.no_grad():
        for _ in range(loop_num):
            output, hn, cn = model(input, hn, cn)
            output_torch.append(to_numpy(output[0]))


    # Export the model to ONNX. The state inputs are traced with random
    # (non-zero) values. Every graph input/output gets an explicit name, and
    # the TVM stage refers to the inputs by these names.
    onnx_path = 'temp.onnx'
    state_shape = (num_layers, batch, hidden_size)
    dummy_input = torch.randn(1, batch, num_hyperparams)
    dummy_h0 = torch.randn(*state_shape)
    dummy_c0 = torch.randn(*state_shape)
    trace_args = (dummy_input, dummy_h0, dummy_c0)

    torch.onnx.export(
        model,
        trace_args,
        onnx_path,
        input_names=['input', 'h0', 'c0'],
        output_names=['output', 'hn', 'cn'],
    )

    # Reference run #2: ONNX Runtime on the exported graph, same recurrence.
    # The CPU provider is named explicitly: onnxruntime >= 1.9 requires the
    # `providers` argument whenever the build has more than one provider.
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_names = [item.name for item in sess.get_inputs()]

    input_np = to_numpy(input)  # loop-invariant: convert once
    hn = to_numpy(h0)
    cn = to_numpy(c0)
    output_onnx = []

    for _ in range(loop_num):
        # sess.run returns NumPy arrays, so hn/cn can be fed straight back in.
        output, hn, cn = sess.run(None,
            {input_names[0]: input_np,
             input_names[1]: hn,
             input_names[2]: cn})
        output_onnx.append(output[0])


    # Import the exported ONNX graph into Relay. No shape dict is passed, so
    # the input shapes recorded in the ONNX file are used.
    onnx_model = onnx.load(onnx_path)
    mod, params = relay.frontend.from_onnx(onnx_model)

    # `relay.build_config` is deprecated; `tvm.transform.PassContext` is the
    # supported way to set the optimization level for relay.build.
    with tvm.transform.PassContext(opt_level=0):
        graph, lib, params_run = relay.build(mod,
                                             target="llvm",
                                             params=params)

    ctx = tvm.cpu()
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params_run)

    input_np = to_numpy(input)  # loop-invariant: convert once
    hn = to_numpy(h0)
    cn = to_numpy(c0)
    output_tvm = []
    for _ in range(loop_num):
        m.set_input('input', input_np)
        m.set_input('h0', hn)
        m.set_input('c0', cn)
        m.run()
        output_tvm.append(m.get_output(0).asnumpy()[0])
        # get_output() without `out=` returns the executor's own output
        # NDArray. Take a host copy so the state fed back next iteration is a
        # plain NumPy array, exactly as in the ONNX Runtime loop.
        hn = m.get_output(1).asnumpy()
        cn = m.get_output(2).asnumpy()

    # Show the first element of the first two steps from each backend.
    # Step 0 starts from zero states; later steps feed back the carried hn/cn.
    for step in range(2):
        print(f"loop = {step}")
        print(f"Pytorch     : {output_torch[step][0, 0]:f}")
        print(f"OnnxRuntime : {output_onnx[step][0, 0]:f}")
        print(f"TVM         : {output_tvm[step][0, 0]:f}")

When you run this code, you will get output similar to the following:

loop = 0
Pytorch     : -0.022901
OnnxRuntime : -0.022901
TVM         : -0.022901
loop = 1
Pytorch     : -0.027888
OnnxRuntime : -0.027888
TVM         : -0.016093

These results suggest that TVM produces correct output when the LSTM's initial hidden and cell states are zero (loop 0), but not when they are non-zero (loop 1 onward). How can I load the PyTorch LSTM model into TVM correctly?

My environment is below:

  • OS: Ubuntu 18.04
  • Pytorch: 1.6.0
  • ONNX: 1.8.1
  • TVM: 0.8.dev0
  • onnxruntime: 1.7.0

Thank you.

Hi there,

I just built TVM from source, and now all three outputs match:

loop = 0
Pytorch     : 0.048932
OnnxRuntime : 0.048932
TVM         : 0.048932
loop = 1
Pytorch     : 0.079360
OnnxRuntime : 0.079360
TVM         : 0.079360