[Bug] [Frontend][Pytorch] Relay IR is inconsistent with that of the original model

Hi, I created a test PyTorch quantized model. The structure is as follows:

There are two dequantize nodes that operate with different scale and zero_point values. When I import the model with TVM, the Relay IR is as follows:

  %0 = qnn.quantize(%input, 0.004226f, 10, out_dtype="uint8", axis=1);
  %1 = nn.pad(%0, pad_value=10f, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]);
  %2 = qnn.conv2d(%1, %backbone.conv1.conv_weight, 10, 0, 0.004226f, 0.0889414f, strides=[2, 2], padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="int32");
  %3 = nn.bias_add(%2, %backbone.conv1.conv_bias);
  %4 = qnn.requantize(%3, 0.000375866f, 0, 0.0899963f, 0, axis=1, out_dtype="int32");
  %5 = clip(%4, a_min=0f, a_max=255f);
  %6 = cast(%5, dtype="uint8");
  %7 = nn.pad(%6, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]);
  %8 = qnn.conv2d(%7, %backbone.conv2.depth_wise.conv_weight, 0, 0, 0.0899963f, 0.0203802f, padding=[0, 0, 0, 0], groups=16, channels=16, kernel_size=[3, 3], out_dtype="int32");
  %9 = nn.bias_add(%8, %backbone.conv2.depth_wise.conv_bias);
  %10 = qnn.requantize(%9, 0.00183414f, 0, 0.123625f, 0, axis=1, out_dtype="int32");
  %11 = clip(%10, a_min=0f, a_max=255f);
  %12 = cast(%11, dtype="uint8");
  %13 = qnn.conv2d(%12, %backbone.conv2.point_wise.conv_weight, 0, 0, 0.123625f, 0.00762315f, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32");
  %14 = nn.bias_add(%13, %backbone.conv2.point_wise.conv_bias);
  %15 = qnn.requantize(%14, 0.000942414f, 0, 0.0926806f, 0, axis=1, out_dtype="int32");
  %16 = clip(%15, a_min=0f, a_max=255f);
  %17 = cast(%16, dtype="uint8");
  %18 = (%6, %17);
  %19 = %18.0;
  %20 = qnn.dequantize(%19, 0.0899963f, 0);
  %21 = %18.1;
  %22 = qnn.dequantize(%21, 0.0899963f, 0);
  (%20, %22)

Why are the scales of the two dequantize nodes in the Relay IR the same? Is this a bug?


Interesting, this looks like a bug. Can you post a complete repro script?

My code:

import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub
from tvm import relay
import numpy as np

class ConvBnRelu(nn.Module):
    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1, bias=True, groups=1):
        super(ConvBnRelu, self).__init__()
        if groups > 1:
            self.conv = nn.Conv2d(inp, inp, kernel_size, stride, padding, bias=bias, groups=groups)
            self.bn = nn.BatchNorm2d(inp)
        else:
            self.conv = nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=bias, groups=groups)
            self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x


def conv_bn(inp, oup, stride=1, width_multiplier=1):
    return ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)


def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
    dw_block = nn.Sequential()
    depth_wise = ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=padding, bias=False, groups=inp)
    point_wise = ConvBnRelu(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)

    dw_block.add_module('depth_wise', depth_wise)
    dw_block.add_module('point_wise', point_wise)

    return dw_block

class Backbone(nn.Module):
    def __init__(self, width_multiplier=1):
        super(Backbone, self).__init__()
        self.width_multiplier = width_multiplier
        self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
        self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)
    
    def forward(self, inputs):
        x1 = self.conv1(inputs)
        x2 = self.conv2(x1)
        return [x1, x2]

class QuantizableBackbone(nn.Module):
    def __init__(self, inputsize=(128, 128)):
        super(QuantizableBackbone, self).__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.backbone = Backbone()

    def fuse_model(self):
        for idx, m in enumerate(self.modules()):
            if type(m) == ConvBnRelu:
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)

    def forward(self, input):
        input = self.quant(input)
        y0, y1 = self.backbone(input)
        y0 = self.dequant(y0)
        y1 = self.dequant(y1)
        return y0, y1

from torch.quantization import prepare_qat, get_default_qat_qconfig, convert
fp32_input = torch.randn(1, 3, 128, 128)
model = QuantizableBackbone()
model.train()
model.fuse_model()
BACKEND = "qnnpack"
model.qconfig = get_default_qat_qconfig(BACKEND)

prepare_qat(model, inplace=True)

model.eval()
y = model(fp32_input)
model_int8 = convert(model, inplace=True)

torch.jit.save(torch.jit.trace(model_int8, fp32_input), "quantized_model.pt")

input_name = "input"
input_infos = [(input_name, ((1, 3, 128, 128), 'float32'))]
# test_image.raw is a local file; any float32 input of shape (1, 3, 128, 128) works.
img_input = np.fromfile("test_image.raw", dtype=np.float32).reshape((1, 3, 128, 128))

pt_input = torch.from_numpy(img_input)
model = torch.jit.load("quantized_model.pt")
script_module = torch.jit.trace(model, pt_input).eval()

mod, params = relay.frontend.from_pytorch(script_module, input_infos)
print(mod)
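
For comparison, the output qparams of the two returned branches can be read directly off the quantized PyTorch model (a rough check, assuming model_int8 from the script above is still in scope; after fusion and convert, the fused quantized conv modules expose .scale and .zero_point):

# Rough check of the PyTorch-side output qparams for the two returned tensors.
print("conv1 output:",
      model_int8.backbone.conv1.conv.scale,
      model_int8.backbone.conv1.conv.zero_point)
print("conv2 output:",
      model_int8.backbone.conv2.point_wise.conv.scale,
      model_int8.backbone.conv2.point_wise.conv.zero_point)

These two scales should differ (they match the 0.0899963f and 0.0926806f requantize scales in the Relay IR above), yet both qnn.dequantize calls in the Relay IR use 0.0899963f.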

I guess it’s caused by a bug in the _get_quant_param_for_input function in qnn_torch.py.

7 defined in (%7 : Tensor, %8 : Tensor = prim::TupleUnpack(%47)
8 defined in (%7 : Tensor, %8 : Tensor = prim::TupleUnpack(%47)

Here %7 and %8 correspond to the two dequantize nodes; their scale and zero_point are obtained by recursing into %47.

OK, I understand the problem. This is due to how the input scale and zero point are looked up.

When we reach the second dequantize below, we need to find its input scale and zp. I do this by DFS on the inputs until we meet a quantized op that produces a new quantized tensor.

  %47 : (Tensor, Tensor) = prim::TupleConstruct(%input.2, %Xq.1)
  %7 : Tensor, %8 : Tensor = prim::TupleUnpack(%47)
  %48 : Tensor = aten::dequantize(%7) # /home/masa/anaconda3/envs/torch-1.10/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:84:0
  %49 : Tensor = aten::dequantize(%8)

The issue is that I assumed the quantized tensor comes earlier in the input list, while here we need to visit the second input of TupleConstruct. So we end up doing DFS on %input.2, which finds the wrong input qparam.

When I implemented that, I was not sure it would cause a problem in practice. Your example is the first instance where this strategy fails.
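
To make the failure mode concrete, here is a simplified toy sketch of that backward DFS (illustrative only, not the actual qnn_torch.py implementation):

# Illustrative sketch of the backward lookup described above (not the actual
# qnn_torch.py code). Each node is a dict with its inputs and, for ops that
# produce a freshly quantized tensor, its output qparams.
def find_input_qparam(node):
    if node["qparam"] is not None:
        # A quantized op that produces a new quantized tensor: stop here.
        return node["qparam"]
    # Otherwise visit the inputs in order and take the first hit. This is
    # exactly what goes wrong for TupleConstruct: the search never reaches
    # the second input.
    for inp in node["inputs"]:
        qparam = find_input_qparam(inp)
        if qparam is not None:
            return qparam
    return None

# Toy graph mirroring the TorchScript snippet above.
input_2 = {"inputs": [], "qparam": (0.0899963, 0)}   # %input.2
xq_1 = {"inputs": [], "qparam": (0.0926806, 0)}      # %Xq.1
tuple_construct = {"inputs": [input_2, xq_1], "qparam": None}

# Both dequantize nodes trace back through the tuple and get (0.0899963, 0),
# even the one that actually reads %Xq.1.
print(find_input_qparam(tuple_construct))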

A better solution is to propagate qparam information forward, so that the backward lookup won't be necessary. For example, in the quantized op converters we can return (_expr.const(inputs[1]), _expr.const(inputs[2])) along with the original Relay output. This is the most elegant solution (also close to what PyTorch does), but it requires many changes in the existing converter functions to retrieve the Relay input.
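
Very roughly, that could look like this (a hypothetical sketch with made-up names, not the existing qnn_torch.py API):

# Hypothetical sketch of the "propagate forward" idea. Each quantized-op
# converter would return its output qparams alongside the Relay output, so
# that a later aten::dequantize can read them directly instead of searching
# backward through the graph.
from collections import namedtuple
from tvm.relay import expr as _expr

QuantizedOutput = namedtuple("QuantizedOutput", ["relay_expr", "scale", "zero_point"])

def make_quantized_output(relay_expr, inputs):
    # For PyTorch quantized ops, inputs[1] / inputs[2] hold the output
    # scale / zero point, as noted above.
    return QuantizedOutput(relay_expr, _expr.const(inputs[1]), _expr.const(inputs[2]))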

An easier, but uglier, way would be to record the output scale and zp in a global dictionary after each conversion. Then we could look up the input scale and zp from anywhere.
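
And a hypothetical sketch of that global-dictionary alternative:

# Hypothetical sketch of the global-dictionary alternative. After converting
# each quantized op, record its output qparams keyed by the name of the
# TorchScript value it produces; converting aten::dequantize then becomes a
# plain lookup, with no backward DFS.
output_qparams = {}

def record_output_qparams(output_name, scale, zero_point):
    output_qparams[output_name] = (scale, zero_point)

def lookup_input_qparams(input_name):
    return output_qparams[input_name]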

Does this make sense? If possible, can you send a PR to fix this problem (in whatever way)? I could do it myself, but since I always find it difficult to get someone to review my PyTorch frontend changes, it might take a while for the change to land. However, if you send a fix, I can review and merge it right away.


Sorry for the late reply! Thanks for the info; let me think about how to implement it. By the way, why is there a combination of prim::TupleConstruct and prim::TupleUnpack in PyTorch's graph? Their presence seems redundant. Why not just use %Xq.2 and %Xq.1 directly?

%68 : (Tensor, Tensor) = prim::TupleConstruct(%Xq.2, %Xq.1)
%7 : Tensor, %8 : Tensor = prim::TupleUnpack(%68)

Yeah, it's strange, but I guess if the tuple creation and unpacking are done in separate parts of the original Python model, they can end up like that.

There is probably a PyTorch JIT pass that removes such tuple pack/unpack pairs. We can run it from our frontend; that would be the easiest solution.

Thank you so much, let me look into it.

As you said, this problem is solved by running that pass in the frontend:

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 217645d81..fb0a72af0 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3520,6 +3520,7 @@ def _run_jit_passes(graph):
         torch._C._jit_pass_onnx_function_substitution(graph)
     else:
         torch._C._jit_pass_inline(graph)
+    torch._C._jit_pass_lower_all_tuples(graph)
 
 
 def _get_tensor_and_var(torch_tensor, name):
@@ -3976,7 +3977,7 @@ def from_pytorch(
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         converter.update_convert_map(qnn_torch.convert_map)
 
-    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
+    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
     if isinstance(ret, list):
         # ListConstruct kept original python list. Convert to tuple.
         ret = _expr.Tuple(ret)
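
The effect of the pass can also be checked in isolation on the traced graph (a small sketch reusing model_int8 and fp32_input from the repro script earlier in the thread):

# Sketch: run the tuple-lowering JIT pass directly on the traced graph and
# verify that the prim::TupleConstruct / prim::TupleUnpack pair disappears.
# Reuses model_int8 and fp32_input from the repro script above.
traced = torch.jit.trace(model_int8, fp32_input)
graph = traced.graph
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_lower_all_tuples(graph)
print(graph)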

Great! I made the same diff, which adds torch._C._jit_pass_lower_all_tuples, and it passes all of our PT tests (with the return value properly handled). I also added your example to our quantized PT model test cases.

Can you open a PR about it? We can merge it today.

The PR has been created; please review.
