[Frontend][Pytorch FX] Can't load quantized linear layer to Relay

When I load a PyTorch FX quantized model into TVM with the code below:

import torch
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qat_qconfig_mapping, quantize_fx
import tvm
from tvm import relay

qconfig_mapping = get_default_qconfig_mapping()
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)
    def forward(self, input):
        x = self.linear(input)
        return x

mm = MyModule()

input = torch.randn((1,900,256))
mm_prepared = quantize_fx.prepare_fx(mm, qconfig_mapping, (input,))
r = mm_prepared(input)
mm_quantize = quantize_fx.convert_fx(mm_prepared)


script_mm = torch.jit.trace(mm_quantize, (input))
input_shapes_mm = [('input', tuple(input.shape))]
mod, params = relay.frontend.from_pytorch(script_mm, input_shapes_mm)

This gives the following error:

The Relay type checker is unable to show the following types match:
  Tensor[(900), int32]
  Tensor[(256), int32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), int32]` does not match `Tensor[(900), int32]`
The Relay type checker is unable to show the following types match:
  Tensor[(900), float32]
  Tensor[(256), float32]
In particular:
  dimension 0 conflicts: 900 does not match 256.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(256), float32]` does not match `Tensor[(900), float32]`
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4649, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 4025, in convert_operators
    self.record_output_type(relay_out)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 220, in record_output_type
    self.infer_type_with_prelude(output)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 168, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/relay/frontend/pytorch.py", line 161, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/opt/conda/envs/python3_9/lib/python3.9/site-packages/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm.error.DiagnosticError: Traceback (most recent call last):
  6: TVMFuncCall
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::transform::Pass::operator()(tvm::IRModule) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::DiagnosticContext::Render()
  File "/workspace/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

Environment:

>>> torch.__version__
'2.0.0+cu117'
>>> tvm.__version__
'0.11.1'

Is there something obvious that I should fix? Thanks!

I found that the reshape operator may be missing from the quantized model’s IR. The quantized model’s IR:

free_var %input: Tensor[(1, 900, 256), float32];
qnn.quantize(%input, 0.0624973f, 62, out_dtype="uint8", axis=1)
relay_out: free_var %input: Tensor[(1, 900, 256), float32];
free_var %linear._packed_params_weight: Tensor[(256, 256), float32];
%0 = qnn.quantize(%input, 0.0624973f, 62, out_dtype="uint8", axis=1);
%1 = qnn.quantize(%linear._packed_params_weight, meta[relay.Constant][0], 0, out_dtype="int8", axis=0);
free_var %linear._packed_params_bias: Tensor[(256), float32];
%2 = qnn.dense(%0, %1, 62, 0, 0.0624973f, meta[relay.Constant][0], units=256, out_dtype="int32");
%3 = qnn.quantize(%linear._packed_params_bias, meta[relay.Constant][1], 0, out_dtype="int32", axis=0);
%4 = nn.bias_add(%2, %3);
%5 = qnn.requantize(%4, meta[relay.Constant][2], 0, 0.0361197f, 66, axis=1, out_dtype="int32");
%6 = clip(%5, a_min=0f, a_max=255f);
cast(%6, dtype="uint8")
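
The "900 does not match 256" messages look consistent with per-channel operands being checked against axis 1 of a 3-D tensor. Below is a minimal pure-Python sketch of that check, not TVM code. It assumes that `qnn.dense` yields a `(1, 900, 256)` tensor when the input is not flattened, and that the `nn.bias_add` above uses axis 1 because no axis is printed; neither is confirmed here. The helper name `axis_conflict` is hypothetical.

```python
def axis_conflict(data_shape, per_channel_len, axis):
    """Compare a per-channel vector length with data_shape[axis].

    Returns None when they agree, otherwise (dim_on_axis, per_channel_len).
    """
    dim = data_shape[axis]
    return None if dim == per_channel_len else (dim, per_channel_len)


# Assumed shape of the dense output when the 3-D input is not flattened first.
dense_out_shape = (1, 900, 256)
# Length of %linear._packed_params_bias, taken from the IR above.
bias_len = 256

print(axis_conflict(dense_out_shape, bias_len, axis=1))   # (900, 256)
print(axis_conflict(dense_out_shape, bias_len, axis=-1))  # None
```

Under these assumptions, axis 1 reproduces the 900 versus 256 conflict, while the last axis (or a 2-D `(900, 256)` tensor with axis 1) does not.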

The original fp32 model’s IR:

relay_out: free_var %input: Tensor[(1, 900, 256), float32];
free_var %linear.weight: Tensor[(256, 256), float32];
%0 = transpose(%linear.weight, axes=[1, 0]);
%1 = reshape(%input, newshape=[-1, 256]);
%2 = transpose(%0, axes=[1, 0]);
%3 = nn.dense(%1, %2, units=None);
%4 = reshape(%3, newshape=[1, 900, 256]);
free_var %linear.bias: Tensor[(256), float32];
nn.bias_add(%4, %linear.bias, axis=-1)
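
For comparison, the fp32 IR flattens the input to 2-D, applies `nn.dense`, and reshapes back before adding the bias on the last axis. The shape bookkeeping of that pattern can be sketched in plain Python as follows. This only tracks shapes and is not the frontend's implementation; the function name `flatten_dense_restore_shapes` is hypothetical.

```python
from math import prod


def flatten_dense_restore_shapes(input_shape, weight_shape):
    """Track shapes through reshape([-1, K]) -> dense -> reshape back.

    weight_shape is (units, in_features), as for a Linear layer's weight.
    """
    units, in_features = weight_shape
    if input_shape[-1] != in_features:
        raise ValueError("last input dimension must equal in_features")
    lead = tuple(input_shape[:-1])
    flat = (prod(lead), in_features)   # reshape(%input, newshape=[-1, 256])
    dense_out = (flat[0], units)       # nn.dense on the 2-D tensor
    restored = lead + (units,)         # reshape(%3, newshape=[1, 900, 256])
    return flat, dense_out, restored


print(flatten_dense_restore_shapes((1, 900, 256), (256, 256)))
# ((900, 256), (900, 256), (1, 900, 256))
```

With the 2-D intermediate, axis 1 is the 256-wide channel axis, which is the step the quantized IR appears to skip.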