[QNN] How to simplify QNN op prior to BYOC partitioning?

Hi, I’m trying to import pre-quantized ONNX and PyTorch models with the TVM frontend, and I noticed that some op patterns in the imported Relay module can be fused into a single qnn op.

For instance, an imported quantized ONNX ResNet-18 model:

def @main(%input: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
  %0 = qnn.quantize(%input, 0.0186584f /* ty=float32 */, 114 /* ty=int32 */, out_dtype="uint8", axis=1) /* ty=Tensor[(1, 3, 224, 224), uint8] */;
  %1 = qnn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), uint8] */, 114 /* ty=int32 */, 128 /* ty=int32 */, 0.0186584f /* ty=float32 */, 0.00308922f /* ty=float32 */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7], out_dtype="int32") /* ty=Tensor[(1, 64, 112, 112), int32] */;
  %2 = nn.bias_add(%1, meta[relay.Constant][1] /* ty=Tensor[(64), int32] */) /* ty=Tensor[(1, 64, 112, 112), int32] */;
  %3 = qnn.requantize(%2, 5.764e-05f /* ty=float32 */, 0 /* ty=int32 */, 0.0281708f /* ty=float32 */, 0 /* ty=int32 */, axis=0, out_dtype="uint8") /* ty=Tensor[(1, 64, 112, 112), uint8] */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), uint8] */;
  %5 = qnn.conv2d(%4, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), uint8] */, 0 /* ty=int32 */, 128 /* ty=int32 */, 0.0281708f /* ty=float32 */, 0.00293729f /* ty=float32 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %6 = nn.bias_add(%5, meta[relay.Constant][3] /* ty=Tensor[(64), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %7 = qnn.requantize(%6, 8.27459e-05f /* ty=float32 */, 0 /* ty=int32 */, 0.0205264f /* ty=float32 */, 0 /* ty=int32 */, axis=0, out_dtype="uint8") /* ty=Tensor[(1, 64, 56, 56), uint8] */;
  %8 = qnn.conv2d(%7, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), uint8] */, 0 /* ty=int32 */, 128 /* ty=int32 */, 0.0205264f /* ty=float32 */, 0.00604534f /* ty=float32 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %9 = nn.bias_add(%8, meta[relay.Constant][5] /* ty=Tensor[(64), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %10 = qnn.requantize(%9, 0.000124089f /* ty=float32 */, 0 /* ty=int32 */, 0.0459817f /* ty=float32 */, 151 /* ty=int32 */, axis=0, out_dtype="uint8") /* ty=Tensor[(1, 64, 56, 56), uint8] */;
  %11 = qnn.dequantize(%10, 0.0459817f /* ty=float32 */, 151 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %12 = qnn.dequantize(%4, 0.0281708f /* ty=float32 */, 0 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %13 = add(%11, %12) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %14 = qnn.quantize(%13, 0.0278099f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="uint8") /* ty=Tensor[(1, 64, 56, 56), uint8] */;
...

where, around %13, the quant(add(dequant(a), dequant(b))) chain can be fused into a single qnn.add op.

Another example, an imported quantized PyTorch ResNet-18 model:

def @main(%input: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
  %0 = qnn.quantize(%input, 0.018622f, 114, out_dtype="uint8", axis=1);
  %1 = nn.pad(%0, 114f, pad_width=[[0, 0], [0, 0], [3, 3], [3, 3]]);
  %2 = qnn.conv2d(%1, %conv1_weight, 114, 0, 0.018622f, 0.00308922f, strides=[2, 2], padding=[0, 0, 0, 0], channels=64, kernel_size=[7, 7], out_dtype="int32");
  %3 = nn.bias_add(%2, %conv1_bias);
  %4 = qnn.requantize(%3, 5.75275e-05f, 0, 0.0146432f, 0, axis=1, out_dtype="int32");
  %5 = clip(%4, a_min=0f, a_max=255f);
  %6 = cast(%5, dtype="uint8");
...

where the requantize: int32 -> clip(0, 255) -> cast: uint8 chain can be fused into a single qnn.requantize op.

Sure, I can capture these patterns during BYOC partitioning, but then I have to handle them in the C++ codegen. Is there an easy way to do such op fusions in Python? I tried DFPatternCallback and tvm.relay.dataflow_pattern.rewrite, but they don’t seem to rewrite all matched patterns in a module.

This is not a BYOC-specific problem but a general problem for pre-quantized models. In other words, it is backend independent.

We usually call this “expression simplification” instead of fusion, because in this case you remove unnecessary ops rather than executing them together in a single kernel. For example:

%1 = dense(%data, %weight);
%2 = add(%1, 0);

This is simplification:

%1 = dense(%data, %weight);

And this is fusion:

%1 = fn(%data, %weight) {
  %a = dense(%data, %weight);
  add(%a, 0);
}
%2 = %1(%data, %weight);

Relay already has a pass for expression simplification, and it seems to me that you could help improve this pass to cover these patterns: https://github.com/apache/tvm/blob/main/src/relay/transforms/simplify_expr.cc
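For reference, the existing pass can already be run from Python; a minimal sketch (assuming mod is your imported module):

from tvm import relay

# Run type inference first, then the existing simplification pass.
# The QNN patterns discussed here would go into simplify_expr.cc itself.
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyExpr()(mod)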

I don’t understand what you mean by the pattern rewrite not covering all matched patterns in a module. Could it be that the pattern you specified is not comprehensive enough to cover all cases?

I also thought that pat match + rewrite would do this job. Do you mean the rewrite only happens on the first match?

Sorry, I’m not quite familiar with the utilities. The example only demonstrates how to rewrite a manually defined graph, but what I need is to find all matched patterns in a Relay module and replace them.

I took a look at the SimplifyExpr pass; it should do the job. All I need to do is add these QNN-related patterns to it.

You can take a look at realistic examples of pat match and rewrite in https://github.com/apache/tvm/blob/2b57907d9ed9b6c31e368ef9b517e528dff3ae9e/python/tvm/relay/frontend/pytorch_utils.py. I wrote all of them, and I remember the rewrite happening on all matched patterns.

Also, the doc you linked says: “The callback function will be invoked recursively on the returned pattern until the pattern stops changing. As a result, if self.pattern matches any part of the graph that the callback returned, the rewriter will run in a loop. If you want to avoid multiple rewrites, you can pass a rewrite_once=True parameter to the constructor.”

So pat match and rewrite is what you want to use.
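For example, a minimal sketch (untested; the class name and the constant handling are mine) of a callback that folds the quantize(add(dequantize, dequantize)) chain from your ONNX dump into a single qnn.add:

from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_constant, is_op, rewrite, wildcard

class FoldDequantAddQuant(DFPatternCallback):
    def __init__(self):
        super().__init__()
        self.lhs, self.rhs = wildcard(), wildcard()
        self.lhs_scale, self.lhs_zp = is_constant(), is_constant()
        self.rhs_scale, self.rhs_zp = is_constant(), is_constant()
        self.out_scale, self.out_zp = is_constant(), is_constant()
        deq_lhs = is_op("qnn.dequantize")(self.lhs, self.lhs_scale, self.lhs_zp)
        deq_rhs = is_op("qnn.dequantize")(self.rhs, self.rhs_scale, self.rhs_zp)
        self.pattern = is_op("qnn.quantize")(
            is_op("add")(deq_lhs, deq_rhs), self.out_scale, self.out_zp
        )

    def callback(self, pre, post, node_map):
        # Rebuild the matched chain as a single qnn.add, reusing the matched
        # scales and zero points.
        return relay.qnn.op.add(
            node_map[self.lhs][0],
            node_map[self.rhs][0],
            node_map[self.lhs_scale][0],
            node_map[self.lhs_zp][0],
            node_map[self.rhs_scale][0],
            node_map[self.rhs_zp][0],
            node_map[self.out_scale][0],
            node_map[self.out_zp][0],
        )

mod["main"] = rewrite(FoldDequantAddQuant(), mod["main"])  # replaces every match in the function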


Thank you.

I also found that the relay.transform.FakeQuantizationToInteger pass covers simplifying the dequantize -> fp32_op -> quantize pattern in modules imported from ONNX.
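A minimal sketch of how I apply it before partitioning (assuming mod is the module returned by relay.frontend.from_onnx):

from tvm import relay

mod = relay.transform.InferType()(mod)
# Rewrites dequantize -> fp32 op -> quantize chains into integer QNN ops
mod = relay.transform.FakeQuantizationToInteger()(mod)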

As for the qnn.requantize: int32 -> clip(0, 255) -> cast: uint8 pattern in modules imported from PyTorch, I found that it is emitted by _do_bias_and_requantize in qnn_torch.py. So why do we need the clip and cast here? Why not set the out_dtype attr of qnn.requantize to uint8, and set the output_zero_point to 0 if there is a ReLU?

Hmm, I don’t quite remember. I was definitely trying to follow what PyTorch does, except that we do the bias add in int32. https://github.com/pytorch/FBGEMM/blob/main/include/fbgemm/OutputProcessing-inl.h#L116-L121

If there is a ReLU, we need to clip with the zero point as the clip min, but qnn.requantize always clips with 0 as the clip min. So I don’t get why doing a single requantize with out_dtype = uint8 and setting the output_zero_point to 0 when there is a ReLU would be equivalent to what we are doing. Let me know if there is non-obvious quantization math happening.

Oh, I just assumed that if there’s a ReLU, the output_zero_point will always be 0 when out_dtype == uint8, since in this case a quint8 value doesn’t need to represent any negative range. I believe this is what any quantizer will do, though they usually just fuse the ReLU op before applying quantization.
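To spell out the math I have in mind (a hand-worked check with made-up scale and zero-point values, not taken from any importer): with an affine scheme x = s * (q - zp), a float-domain ReLU corresponds to max(q, zp) in the quantized domain, so when zp = 0 it collapses to exactly the clip at 0 that saturating to uint8 already provides:

import numpy as np

s, zp = 0.02, 34                       # made-up scale and zero point
q = np.arange(0, 256, dtype=np.int64)  # quantized values
x = s * (q - zp)                       # dequantized (float) values

relu_then_requant = np.round(np.maximum(x, 0.0) / s) + zp  # ReLU in float, back to quantized domain
clip_at_zp = np.maximum(q, zp)                             # clip directly in the quantized domain

assert (relu_then_requant == clip_at_zp).all()
# When zp == 0, clip_at_zp reduces to clip(q, 0, ...), i.e. the lower
# saturation bound that requantizing straight to uint8 already enforces.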

Not sure about that. PyTorch uses uint8 exclusively, and it does have non-zero zero points. You can test the accuracy on ImageNet with and without your change.

OK, I’ll investigate this. Although from my observation, the zero point of the intermediate quint8 tensors after ReLU is always 0, at least in ResNet-18.

Yeah, if what you are claiming is true, that’s something I didn’t know about. But even if that’s the case, I believe it is important to follow what PyTorch does as much as possible by default. Optimizations like this can always be done by users after the model is imported, if they know what they are doing.

You can use this repo https://github.com/Edgecortix-Inc/pytorch_quantization/tree/master/tvm_qnn_evaluation to quickly test accuracy on many torchvision models before and after converting to TVM.
