I am trying to compile DLRM from PyTorch, and I get an error when I call relay.frontend.from_pytorch. Here is the information I have gathered so far. The relevant part of the DLRM net source code:
# concatenate dense and sparse features
(batch_size, d) = x.shape
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
# perform a dot product
Z = torch.bmm(T, torch.transpose(T, 1, 2))
# append dense feature with the interactions (into a row vector)
# approach 1: all
# Zflat = Z.view((batch_size, -1))
# approach 2: unique
_, ni, nj = Z.shape
# approach 1: tril_indices
# offset = 0 if self.arch_interaction_itself else -1
# li, lj = torch.tril_indices(ni, nj, offset=offset)
# approach 2: custom
offset = 1 if self.arch_interaction_itself else 0
li = torch.tensor([i for i in range(ni) for j in range(i + offset)])
lj = torch.tensor([j for i in range(nj) for j in range(i + offset)])
Zflat = Z[:, li, lj]
# concatenate dense features and interactions
R = torch.cat([x] + [Zflat], dim=1)
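For reference, here is a minimal, self-contained sketch of the interaction step above. The batch size (128) and feature count (27) match the traced shapes below; the embedding dimension d = 16 is only an assumed value for illustration:

import torch

batch_size, n, d = 128, 27, 16  # 128 and 27 match the trace below; d = 16 is assumed
T = torch.randn(batch_size, n, d)
Z = torch.bmm(T, torch.transpose(T, 1, 2))   # (128, 27, 27) pairwise dot products
offset = 0                                   # arch_interaction_itself = False
li = torch.tensor([i for i in range(n) for j in range(i + offset)])
lj = torch.tensor([j for i in range(n) for j in range(i + offset)])
Zflat = Z[:, li, lj]                         # (128, 351) strictly lower-triangular entries
print(Zflat.shape)                           # torch.Size([128, 351])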
Zflat is obtained by slicing and fancy-indexing Z. When I trace the model with torch.jit.trace, the corresponding part of the scripted model's graph is:
%40 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%41 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%42 : int = prim::Constant[value=9223372036854775807]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%43 : int = prim::Constant[value=1]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%44 : Float(128, 27, 27, strides=[729, 27, 1], requires_grad=1, device=cpu) = aten::slice(%Z, %40, %41, %42, %43) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%45 : int = prim::Constant[value=4]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%46 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%47 : Device = prim::Constant[value="cpu"]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%48 : None = prim::Constant()
%49 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%50 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%51 : None = prim::Constant()
%li : Long(351, strides=[1], requires_grad=0, device=cpu) = aten::to(%li.1, %45, %46, %47, %48, %49, %50, %51) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%53 : int = prim::Constant[value=4]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%54 : int = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%55 : Device = prim::Constant[value="cpu"]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%56 : None = prim::Constant()
%57 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%58 : bool = prim::Constant[value=0]() # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%59 : None = prim::Constant()
%lj : Long(351, strides=[1], requires_grad=0, device=cpu) = aten::to(%lj.1, %53, %54, %55, %56, %57, %58, %59) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%61 : None = prim::Constant()
%62 : Tensor?[] = prim::ListConstruct(%61, %li, %lj)
%Zflat : Float(128, 351, strides=[351, 1], requires_grad=1, device=cpu) = aten::index(%44, %62) # ./code/git/trunk/tensorxt/tests/graph/dlrm/dlrm_s_pytorch.py:493:0
%Zflat is computed by aten::index. %44 is the output of aten::slice, %62 is the index list passed to aten::index, and %61 is a None constant, which I think stands for the full slice (:) over the first (batch) dimension.
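The pattern can be reproduced in isolation: tracing a small function that does z[:, li, lj] should yield the same prim::ListConstruct with a None entry (standing in for the full slice on dim 0) feeding aten::index. The function and tensor sizes below are made up purely for illustration:

import torch

def interact(z, li, lj):
    return z[:, li, lj]  # the ":" on dim 0 becomes the None entry in the index list

z = torch.randn(2, 4, 4)
li = torch.tensor([1, 2, 2, 3, 3, 3])
lj = torch.tensor([0, 0, 1, 0, 1, 2])
traced = torch.jit.trace(interact, (z, li, lj))
print(traced.graph)      # expect prim::ListConstruct(%none, %li, %lj) -> aten::index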
TVM cannot handle this NoneType constant and raises the following exception:

Exception: warning unhandled case: <class 'NoneType'>

To reproduce the problem, unzip scripted_model.zip and run this script:
from tvm import relay
import torch
scripted_model = torch.jit.load('scripted_model.pt')
# input name/shape pairs: the dense features ("dense_x") and the sparse embedding outputs ("ly")
relay_mod, params = relay.frontend.from_pytorch(
    scripted_model, [("dense_x", (128, 13)), ("ly", (128, 416))]
)
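If it helps, the input names passed to from_pytorch can be cross-checked against the traced graph (the first graph input is usually the module's self):

print([inp.debugName() for inp in scripted_model.graph.inputs()])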