[PyTorch][QNN] Cannot import Quantized MobileBERT produced by FX graph mode quantization

Hi, I am trying to import a scripted quantized MobileBERT torch model via the relay.frontend.from_pytorch API, which works fine for quantized vision models like resnet50. However, I ran into the following problem (possibly due to the complexity of the BERT model):

When the conversion reaches

input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)

it fails inside the qnn_torch.add_input_quant_params_to_op_inputs() function, in this block:

if "quantized::conv" in operator or "quantized::linear" in operator:
    # This is required for quantizing the bias
    assert len(input_scales) == 1, "One quantized parameter expected for qconv or qlinear."
    input_scales_for_bias[node.inputsAt(1).debugName()] = input_scales[0].node().f("value")

it expects input_scales[0].node() to be something like:

%493 : float = prim::Constant[value=0.01865844801068306]()

(This is the case in quantized resnet50.)

Here the value of the scale can be read directly because the node is a constant.
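For context on why the converter needs the input scale at this point: in the usual QNN convention, the int32 bias of a quantized conv or linear layer is quantized with scale input_scale * weight_scale and zero point 0. Below is a minimal plain-Python sketch of that arithmetic. Only the input scale is taken from the resnet50 node above; the weight scale, the bias values, and the quantize_bias helper are made up for illustration and are not TVM code.

```python
def quantize_bias(bias, input_scale, weight_scale):
    """Quantize a float bias to integers with scale input_scale * weight_scale."""
    bias_scale = input_scale * weight_scale
    return [int(round(b / bias_scale)) for b in bias]


input_scale = 0.01865844801068306   # constant from the resnet50 node above
weight_scale = 0.002                # hypothetical value, for illustration only
bias = [0.5, -0.25, 0.0]            # hypothetical float bias

q_bias = quantize_bias(bias, input_scale, weight_scale)
# Dequantizing recovers the float bias up to half a quantization step.
recovered = [q * input_scale * weight_scale for q in q_bias]
```

Without a concrete input scale, the bias scale cannot be computed, which is why a scale that is not a constant is a problem here.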

However, in the MobileBERT case, input_scales[0].node() is actually something like:

%cat_output_scale_0.1 : Tensor = prim::GetAttr[name="cat_output_scale_0.1"](%self.1)

This is not a constant: the scale has to be fetched as an attribute of %self.1 (i.e. the script module). An error is therefore thrown, because there is no “value” attribute for input_scales[0].node().f to read here.

I tried to hack around this by passing in the script_module and fetching the scale constant by hand via something like float(getattr(script_module, func_name)), and that does eliminate this error. However, it is not a complete solution: I still get more errors in the following steps because some of the nodes cannot be recognized or parsed correctly.

Any suggestions for this problem? Thanks : )

@masahi @lhutton1 @comaniac

The best solution would be to “inline” such GetAttr nodes, so that qparams are always directly accessible. I believe such a transformation is not hard to do in FX.

Hi masahi, thanks for the advice! I “inlined” all constants on the torch FX end and the problem is solved. However, I now get a new error at the new_mod = transform.InferType()(new_mod) step after converting a linear layer op.

The error says:

The Relay type checker is unable to show the following types match:
  Tensor[(384), int32]
  Tensor[(512), int32]
In particular:
  dimension 0 conflicts: 384 does not match 512.

This is a quantized BERT model: the input tensor of this linear layer has shape (1, 384, 384) and the weight has shape (384, 512), which is slightly different from what the input of a linear layer looks like in computer vision models. Could this break the default quantization operations and cause the error?

Our dense op expects the weight to be transposed, i.e. (512, 384) in your case. You can apply op.transpose to your weight.

Ah sorry, there’s a typo here. The weight’s shape is already (512, 384) and I’m still getting this error. My guess is that the quantized dense operator expects a 2-dimensional input tensor and a 2-dimensional weight. Since the input in the BERT case is a 3-dimensional tensor, it can’t handle it properly. I’m thinking of using the quantized batch_matmul operator as an alternative and reshaping the weight tensor to (1, 512, 384) to match the input tensor’s (1, 384, 384). Do you think this is doable?

The non-quantized case, for comparison, has an adaptive mapping strategy:

def linear(self, inputs, input_types):
    # https://pytorch.org/docs/stable/nn.functional.html#linear
    # 0 - input
    # 1 - weight
    bias = inputs[2]
    a_shape = self.infer_shape_with_prelude(inputs[0])
    b_shape = self.infer_shape_with_prelude(inputs[1])
    if len(a_shape) == 2 and len(b_shape) == 2:
        mm_out = _op.nn.dense(inputs[0], inputs[1])
    elif len(b_shape) == 1:
        mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
    else:
        # General case: transpose the 2-D weight and fall back to matmul
        mm_out = self.matmul(
            [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
        )
    if isinstance(bias, _expr.Expr):
        bias_ndims = len(self.infer_shape_with_prelude(bias))
        if bias_ndims == 1:
            return _op.nn.bias_add(mm_out, bias, axis=-1)
        mm_dtype = self.infer_type_with_prelude(mm_out).dtype
        return self.add([mm_out, bias], [mm_dtype, input_types[2]])
    return mm_out

I guess that’s what we are missing in the quantized case in order to support BERT.

Do you think there’s a workaround for this case?

Given the error message above, I think it is complaining about the shape of the bias tensor.

I reshaped the input tensor from (1, 384, 384) to (384, 384) so that qnn.dense can take it in properly, and it seems to be working for now. I still need to check the final correctness of the whole model. However, if I have to keep this reshape, the batch size cannot be larger than 1.
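One way this reshape workaround might extend beyond batch size 1 is to fold every leading dimension into the row dimension before the 2-D dense and restore them afterwards. The plain-Python sketch below only tracks shapes. dense_shape and dense_nd_shape are hypothetical helpers, not TVM functions, and the weight is assumed to be laid out as (out_features, in_features).

```python
from math import prod


def dense_shape(data_shape, weight_shape):
    """Output shape of a 2-D dense: (rows, in) x (out, in) -> (rows, out)."""
    rows, in_features = data_shape
    out_features, weight_in = weight_shape
    if in_features != weight_in:
        raise ValueError(f"{in_features} does not match {weight_in}")
    return (rows, out_features)


def dense_nd_shape(data_shape, weight_shape):
    """Fold leading dims into rows, apply the 2-D dense, then unfold them."""
    *leading, in_features = data_shape
    flat = (prod(leading), in_features)       # e.g. (1, 384, 384) -> (384, 384)
    _, out_features = dense_shape(flat, weight_shape)
    return (*leading, out_features)           # e.g. -> (1, 384, 512)


print(dense_nd_shape((1, 384, 384), (512, 384)))  # (1, 384, 512)
print(dense_nd_shape((8, 384, 384), (512, 384)))  # (8, 384, 512)
```

In Relay the same idea would presumably be a reshape to (-1, in_features) before qnn.dense and a reshape back after it; that has not been checked against the converter here.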

Just a quick update, I solved this problem.

It turns out that you are correct. The default axis for requantize and bias_add is 1, whereas in my case it should be axis=-1 (or axis=2), since the output of my dense operator is 3-dimensional (1, 384, 512) and the bias, as well as the requantize factors, should be applied along the last axis.

The problem is solved by changing the axis param of requantize and bias_add operator to -1.
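The mismatch can be reproduced on shapes alone. The sketch below is a plain-Python stand-in for the shape rule involved (a 1-D bias must have the same length as the data extent along the chosen axis). It is not the Relay type checker, and check_bias_axis is a hypothetical helper.

```python
def check_bias_axis(data_shape, bias_len, axis):
    """Return True if a 1-D bias of length bias_len fits data_shape along axis."""
    return data_shape[axis] == bias_len


dense_out = (1, 384, 512)   # output shape of the dense op in this thread
bias_len = 512

print(check_bias_axis(dense_out, bias_len, axis=1))   # False: 384 does not match 512
print(check_bias_axis(dense_out, bias_len, axis=-1))  # True
print(check_bias_axis((384, 512), bias_len, axis=1))  # True: 2-D outputs are fine with axis=1
```

The last line shows why axis=1 goes unnoticed for vision models, where the dense output is 2-dimensional and axis 1 is also the last axis.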

Hi @mgeek, great to hear that you managed to get the quantized MobileBERT compiled by TVM. Could you please share your setup and scripts? I was searching for information on quantized PyTorch model compilation, and there are far more questions than answers. It would help the community greatly if we could learn from your investigation. Thanks.

Sure, I guess I can write a post about it as soon as I get some time. However, it’s not easy to put all the changes in one script, since it involves some minor changes here and there…

If you are stuck somewhere as well when trying to import quantized mobileBERT, I’m more than glad to give it a look and see if there’s anything I can do to help.

Hi @mgeek thanks. I can try to go through this process with your help and summarize the learning in a post and share. I’ll start with DM first.