[relay][frontend] Importing the DLRM net graph fails

I am trying to compile DLRM from PyTorch, but calling relay.frontend.from_pytorch raises an error. Here is a summary of what I found. This is the relevant part of the DLRM net source code:

            # concatenate dense and sparse features
            (batch_size, d) = x.shape
            T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
            # perform a dot product
            Z = torch.bmm(T, torch.transpose(T, 1, 2))
            # append dense feature with the interactions (into a row vector)
            # approach 1: all
            # Zflat = Z.view((batch_size, -1))
            # approach 2: unique
            _, ni, nj = Z.shape
            # approach 1: tril_indices
            # offset = 0 if self.arch_interaction_itself else -1
            # li, lj = torch.tril_indices(ni, nj, offset=offset)
            # approach 2: custom
            offset = 1 if self.arch_interaction_itself else 0
            li = torch.tensor([i for i in range(ni) for j in range(i + offset)])
            lj = torch.tensor([j for i in range(nj) for j in range(i + offset)])
            Zflat = Z[:, li, lj]
            # concatenate dense features and interactions
            R = torch.cat([x] + [Zflat], dim=1)
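As a sanity check on the shapes that appear in the traced graph below, the index construction from the snippet can be replayed in plain Python. This is only a sketch: interaction_indices is a hypothetical helper name, ni = nj = 27 is read off the Float(128, 27, 27) shape of Z, and arch_interaction_itself=False is assumed because that is the setting that yields 351 index pairs.

```python
def interaction_indices(ni, nj, arch_interaction_itself=False):
    """Rebuild li/lj the same way as the DLRM snippet (approach 2: custom)."""
    offset = 1 if arch_interaction_itself else 0
    li = [i for i in range(ni) for j in range(i + offset)]
    lj = [j for i in range(nj) for j in range(i + offset)]
    return li, lj


# ni = nj = 27, taken from Z's traced shape Float(128, 27, 27)
li, lj = interaction_indices(27, 27)
print(len(li))  # 351, i.e. 27 * 26 // 2, matching Long(351) for %li and %lj
# With offset 0 every selected pair lies strictly below the diagonal (i > j)
print(all(i > j for i, j in zip(li, lj)))  # True
```

With arch_interaction_itself=True the diagonal is included as well and the count becomes 378, so the 351 in the graph points to the strictly lower triangle of Z.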

Zflat is obtained by slicing and indexing Z. After tracing the model with torch.jit.trace, the relevant part of the traced graph looks like this:

  %40 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %41 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %42 : int = prim::Constant[value=9223372036854775807]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %43 : int = prim::Constant[value=1]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %44 : Float(128, 27, 27, strides=[729, 27, 1], requires_grad=1, device=cpu) = aten::slice(%Z, %40, %41, %42, %43) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %45 : int = prim::Constant[value=4]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %46 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %47 : Device = prim::Constant[value="cpu"]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %48 : None = prim::Constant()
  %49 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %50 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %51 : None = prim::Constant()
  %li : Long(351, strides=[1], requires_grad=0, device=cpu) = aten::to(%li.1, %45, %46, %47, %48, %49, %50, %51) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %53 : int = prim::Constant[value=4]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %54 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %55 : Device = prim::Constant[value="cpu"]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %56 : None = prim::Constant()
  %57 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %58 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %59 : None = prim::Constant()
  %lj : Long(351, strides=[1], requires_grad=0, device=cpu) = aten::to(%lj.1, %53, %54, %55, %56, %57, %58, %59) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
  %61 : None = prim::Constant()
  %62 : Tensor?[] = prim::ListConstruct(%61, %li, %lj)
  %Zflat : Float(128, 351, strides=[351, 1], requires_grad=1, device=cpu) = aten::index(%44, %62) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0

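%Zflat is computed by aten::index: %44 is the output of aten::slice, and %62 is the list of indices passed to aten::index. Its first element, %61, is a None constant. In the index list of aten::index, None means the corresponding dimension is not indexed and is kept whole, i.e. it corresponds to the ":" in Z[:, li, lj].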
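To make the None semantics concrete, here is a minimal pure-Python reference for what aten::index(Z, [None, li, lj]) computes on a batch of matrices. It is an illustration rather than PyTorch's implementation: index_none_li_lj is a hypothetical name, and it only covers this exact pattern (a leading None followed by two equal-length index lists).

```python
def index_none_li_lj(Z, li, lj):
    """Hypothetical reference for aten::index(Z, [None, li, lj]).

    The None entry leaves dimension 0 untouched (like ':'), while li and lj
    are paired elementwise, so output[b][k] = Z[b][li[k]][lj[k]].
    """
    assert len(li) == len(lj), "li and lj must have the same length"
    return [[mat[i][j] for i, j in zip(li, lj)] for mat in Z]


# Tiny example: a batch of 2 matrices of size 3x3
Z = [
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
]
li, lj = [1, 2, 2], [0, 0, 1]  # strictly lower triangle of a 3x3 matrix
print(index_none_li_lj(Z, li, lj))  # [[3, 6, 7], [13, 16, 17]]
```

The result has shape (batch, len(li)), which is the small-scale analogue of the Float(128, 351) shape of %Zflat.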

This is how TVM handles the NoneType constant:

The exception raised:

Exception: warning unhandled case: <class 'NoneType'>
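One possible workaround, an untested assumption rather than a confirmed fix, is to keep the traced graph from emitting a None index at all: flatten the last two dimensions of Z and gather with a single linear index, e.g. Zflat = Z.reshape(batch_size, -1).index_select(1, li * nj + lj) in the model code. Whether the TVM version in use converts aten::reshape and aten::index_select cleanly is something to verify. The pure-Python sketch below only checks the index arithmetic: for a row-major ni x nj matrix, element (i, j) sits at flat position i * nj + j. gather_linear and flatten_matrix are hypothetical helper names.

```python
def flatten_matrix(mat):
    """Row-major flattening of a 2-D nested list, like reshape on contiguous data."""
    return [v for row in mat for v in row]


def gather_linear(Z, li, lj, nj):
    """Gather Z[b][li[k]][lj[k]] using one linear index per pair: i * nj + j."""
    flat_idx = [i * nj + j for i, j in zip(li, lj)]
    out = []
    for mat in Z:
        flat = flatten_matrix(mat)  # shape (ni * nj,) per batch element
        out.append([flat[k] for k in flat_idx])
    return out


Z = [
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
]
li, lj = [1, 2, 2], [0, 0, 1]
print(gather_linear(Z, li, lj, nj=3))  # [[3, 6, 7], [13, 16, 17]]
```

This gives the same result as indexing with Z[:, li, lj], but expresses the gather over a single dimension, so no None placeholder is needed in the index list.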

To reproduce the problem, unzip scripted_model.zip and run the following script:

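from tvm import relay
import torch

# Load the traced DLRM model extracted from scripted_model.zip
scripted_model = torch.jit.load('scripted_model.pt')
# (input name, shape) pairs for the traced graph's inputs;
# this call fails with: Exception: warning unhandled case: <class 'NoneType'>
relay_mod, params = relay.frontend.from_pytorch(scripted_model, [("dense_x", (128, 13)), ("ly", (128, 416))])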

https://github.com/apache/tvm/files/6737890/scripted_model.zip