Error when importing a quantized PyTorch model

Hi,

I'm trying to train a quantization-aware model with PyTorch (1.7.0) and optimize it with AutoTVM (TVM 0.8, latest). My model is quite simple and follows a VGG-like structure. I build the model in PyTorch and export it with `torch.jit.script`:

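```python
import torch

# `model` is the QAT-trained network; `args.model` is a name suffix from my training script.
model = torch.quantization.convert(model)  # replace QAT modules with quantized ones
m = torch.jit.script(model)  # compile to TorchScript
torch.jit.save(m, "sound" + args.model + ".jit.pt")
```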

I get two errors when importing it with `relay.frontend.from_pytorch`:

1. The following operators are not implemented: `['prim::unchecked_cast', 'aten::__is__']`

These two operators come from the `stride is None` check that PyTorch's `max_pool2d` performs in `torch/nn/functional.py`, and they appear in the graph as follows:

```
%stride0.3 : int[] = prim::If(%67) # /home/thales/Documents/pytorchvenv36/lib/python3.6/site-packages/torch/nn/functional.py:536:4
  block0():
    %stride0.4 : int[] = prim::ListConstruct()
    -> (%stride0.4)
  block1():
    %stride0.5 : int[] = prim::unchecked_cast(%63)
    -> (%stride0.5)
%input2.1 : Tensor = aten::max_pool2d(%input1.2, %62, %stride0.3, %64, %65, %58) # /home/thales/Documents/pytorchvenv36/lib/python3.6/site-packages/torch/nn/functional.py:538:11
```
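
As far as I can tell, the `prim::If` above is only the `stride is None` check, and `prim::unchecked_cast` just narrows the static type from `Optional[List[int]]` to `List[int]` without doing anything at runtime. In plain Python it amounts to roughly the following sketch (`resolve_stride` is a name I made up for illustration, not a PyTorch function):

```python
from typing import List, Optional


def resolve_stride(stride: Optional[List[int]]) -> List[int]:
    """Plain-Python mirror of the prim::If in the graph (hypothetical helper)."""
    if stride is None:  # aten::__is__(stride, None) produces the condition %67
        return []  # block0: prim::ListConstruct() builds an empty list
    return stride  # block1: prim::unchecked_cast is a pure type refinement
```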

I worked around this by passing a custom conversion map to `relay.frontend.from_pytorch`:

```python
from tvm import relay

custom_map = {"prim::unchecked_cast": relay.frontend.pytorch.PyTorchOpConverter.identity,
              "aten::__is__": relay.frontend.pytorch.PyTorchOpConverter.is_floating_point}
```
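
For comparison, this is how I understand a custom map is meant to look: plain functions called as `fn(inputs, input_types)`, which, as far as I know, is the convention TVM's own `custom_convert_map` tests use. If `identity` and `is_floating_point` are regular instance methods of `PyTorchOpConverter`, the unbound versions above would also expect a `self` argument. Both converters below are my own hypothetical sketches, and `convert_is` assumes its operands are plain Python values (such as `None` or a constant list) at conversion time, which I have not verified:

```python
# Hypothetical converters for relay.frontend.from_pytorch(custom_convert_map=...).
# Assumption: each converter is called as fn(inputs, input_types).


def convert_unchecked_cast(inputs, input_types):
    """prim::unchecked_cast only refines the static type, so pass the value through."""
    return inputs[0]


def convert_is(inputs, input_types):
    """aten::__is__ is Python's `is` operator (e.g. `stride is None`).

    Assumes both operands are already plain Python values at conversion time;
    Relay expression operands are not handled by this sketch.
    """
    return inputs[0] is inputs[1]


custom_map = {
    "prim::unchecked_cast": convert_unchecked_cast,
    "aten::__is__": convert_is,
}

# Usage (requires TVM; script_module and input_infos come from your own code):
# mod, params = relay.frontend.from_pytorch(script_module, input_infos,
#                                           custom_convert_map=custom_map)
```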

2. Now I have a problem with an operator input that is not recognized:

```
%4 : __torch__.torch.nn.quantized.modules.Quantize = prim::GetAttr[name="quant"]
%19 : int = prim::GetAttr[name="dtype"](%4)
%xq.1 : Tensor = aten::quantize_per_tensor(%x.1, %16, %18, %19) # /home/thales/Documents/pytorchvenv36/lib/python3.6/site-packages/torch/nn/quantized/modules/__init__.py:42:15
```

The TorchScript graph reads the `quant` submodule (a `Quantize` module) and, more specifically, its `dtype` attribute. However, `_get_operator_nodes(nodes)` in TVM's PyTorch frontend skips nodes of kind `prim::GetAttr`, so their outputs are never recorded. As a result, `%19`, an input of `aten::quantize_per_tensor` (`%xq.1`), is unknown, and `_get_op_inputs(op_node, outputs)` raises an error.
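
To check my understanding of the failure, I reduced it to a toy model. This is NOT TVM's code: the node tuples and the functions `get_operator_nodes`, `get_op_inputs` and `convert` are simplified stand-ins I wrote for illustration, only loosely named after the frontend's private helpers:

```python
# Toy model of the failure described above; NOT TVM's real implementation.
# Each node is (output_name, kind, input_names); all values are hypothetical.
NODES = [
    ("%19", "prim::GetAttr", ["%4"]),
    ("%xq.1", "aten::quantize_per_tensor", ["%x.1", "%16", "%18", "%19"]),
]


def get_operator_nodes(nodes):
    """Keep only operator nodes, dropping prim::GetAttr as described above."""
    return [node for node in nodes if node[1] != "prim::GetAttr"]


def get_op_inputs(node, outputs):
    """Look up every input of `node` among the values recorded so far."""
    return [outputs[name] for name in node[2]]


def convert(nodes, known_values):
    """Walk the operator nodes in order, recording each node's output."""
    outputs = dict(known_values)  # values available before any operator runs
    for node in get_operator_nodes(nodes):
        outputs[node[0]] = (node[1], get_op_inputs(node, outputs))
    return outputs


try:
    convert(NODES, {"%x.1": "input", "%16": "scale", "%18": "zero_point"})
except KeyError as err:
    print("unknown input:", err)  # prints: unknown input: '%19'
```

Because the `prim::GetAttr` node producing `%19` is filtered out before conversion, nothing ever records `%19`, so the lookup for `%xq.1` fails.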

Do you have any clue what is causing this? Could the problem be that my model is not well defined?