Hi, I have been struggling for several hours with the following issue:
I have a model in PyTorch that I want to convert to TVM; it consists of two GRU layers and two fully connected layers. From what I have found so far, TVM does not yet support RNN operators when converting from PyTorch directly. Therefore I decided to convert my model to ONNX first and then convert that to TVM (both the GRU and LSTM operators seem to be supported there), but I got an error that I am struggling with.
Here is the code of the model in PyTorch:
import onnx
import torch
from torch import nn
import tvm
import pathlib
from tvm import relay
class MyModel(nn.Module):
    def __init__(self, architecture, total_dim, GRU1_dim, GRU2_dim, FC1_dim):
        super(MyModel, self).__init__()
        FC2_dim = total_dim * 4
        self.architecture = architecture
        self.total_dim = total_dim
        # Two stacked recurrent layers, either GRU or LSTM
        if architecture == 'GRU':
            self.GRU1 = nn.GRU(input_size=total_dim,
                               hidden_size=GRU1_dim,
                               batch_first=True)
            self.GRU2 = nn.GRU(input_size=GRU1_dim,
                               hidden_size=GRU2_dim,
                               batch_first=True)
        elif architecture == 'LSTM':
            self.LSTM1 = nn.LSTM(input_size=total_dim,
                                 hidden_size=GRU1_dim,
                                 batch_first=True)
            self.LSTM2 = nn.LSTM(input_size=GRU1_dim,
                                 hidden_size=GRU2_dim,
                                 batch_first=True)
        # Two fully connected layers on top of the last time step
        self.fc1 = nn.Linear(in_features=GRU2_dim,
                             out_features=FC1_dim)
        self.fc2 = nn.Linear(in_features=FC1_dim,
                             out_features=FC2_dim)

    def forward(self, inp):
        if self.architecture == 'GRU':
            x, _ = self.GRU1(inp)
            x, _ = self.GRU2(x)
        elif self.architecture == 'LSTM':
            x, _ = self.LSTM1(inp)
            x, _ = self.LSTM2(x)
        x = x[:, -1, :]  # keep only the last time step
        x = torch.exp(self.fc1(x))
        x = torch.exp(self.fc2(x))
        return x.view(inp.shape[0], 4, self.total_dim)
As you can see, it is pretty straightforward. I convert it to the ONNX format first:
torch_model = MyModel("GRU", 14, 10, 10, 10)
input_data = torch.randn((1, 10, 14), requires_grad=True)
example_outputs = torch_model(input_data)
torch.onnx.export(torch_model,
                  input_data,
                  "model.onnx",
                  example_outputs=example_outputs)
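To see whether the dynamic dimension already appears at the export stage, here is a sketch of how I would inspect the exported graph using the onnx package's own checker and shape inference (file name matches the export above):

# Load the exported model, validate it, and print the inferred output shapes
model = onnx.load("model.onnx")
onnx.checker.check_model(model)
inferred = onnx.shape_inference.infer_shapes(model)
for out in inferred.graph.output:
    # dim_param is set for symbolic (dynamic) dims, dim_value for static ones
    dims = [d.dim_param or d.dim_value for d in out.type.tensor_type.shape.dim]
    print(out.name, dims)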
Then I try to convert it to TVM as shown in the tutorial:
onnx_model = onnx.load(pathlib.Path.cwd().joinpath("model.onnx"))
input_data = input_data.detach().numpy()
input_name = "1"
shape_dict = {input_name: input_data.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

target = "llvm"     # CPU target, matching tvm.cpu(0) below
dtype = "float32"
with tvm.transform.PassContext(opt_level=1):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target)
    intrp.evaluate()(tvm.nd.array(input_data.astype(dtype)), **params).asnumpy()
The error I am getting when calling intrp.evaluate() is the following:
ValueError: ('Graph Runtime only supports static graphs, got output type', TensorType([?, ?, ?], float32))
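For what it's worth, the dynamic output type can also be seen directly on the Relay module after type inference (a sketch, assuming TVM's InferType pass):

# Run type inference and print the return type of the main function;
# this should show the same TensorType reported in the error
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)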
I have been trying to debug it, but since I am not really good at C++, I find myself stuck.
How can the output be dynamic if the ONNX graph shows all the dimensions when I view it in Netron? Can anyone elaborate on that? Does the problem already occur at the ONNX conversion stage, or after it?
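If a dynamic output type is expected for converted RNNs, would switching to the VM executor be the right workaround? As far as I understand, it supports dynamic shapes, unlike the graph runtime. A sketch of what I mean, reusing mod, params, target, and dtype from above:

# Same pipeline, but with the VM executor instead of "graph"
with tvm.transform.PassContext(opt_level=1):
    vm_exec = relay.build_module.create_executor("vm", mod, tvm.cpu(0), target)
    result = vm_exec.evaluate()(tvm.nd.array(input_data.astype(dtype)), **params)
print(result.asnumpy())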