[QNN][PyTorch] QAT Performance Drop Off

Dear community, lately I've been playing around with QAT at the PyTorch level. My model is a custom CNN/MLP model for image classification, containing only the following layers:

  • Conv2D
  • MaxPool2D
  • Linear
  • Dropout (for training only obv.)
  • QuantStub/DeQuantStub

Without quantization the accuracy was around 92%. Using quantization-aware training (following PyTorch's guide: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#quantization-aware-training), the accuracy was still at a high level (around 91.6%) in PyTorch. Then, following TVM's guide for deploying a prequantized model (https://tvm.apache.org/docs/tutorials/frontend/deploy_prequantized.html), the measured accuracy at the TVM level (over the whole test set for my QAT model) dropped to 60%. I was training on PyTorch 1.5.1, deploying on x86, and quantizing through QNN at the TVM level (which I could see in the Relay IR).
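For reference, my flow on the PyTorch side was roughly the following. This is only a minimal sketch along the lines of the tutorial; `MyCnn`, the qconfig backend and the input shape are placeholders, not my exact code:

```python
import torch
import torch.quantization as tq

# Placeholder for the custom Conv2D/MaxPool2D/Linear/Dropout model described above,
# with QuantStub/DeQuantStub wrapping the forward pass.
model = MyCnn()
model.train()

# QAT configuration as in the PyTorch static-quantization tutorial
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
tq.prepare_qat(model, inplace=True)

# ... regular training loop with fake-quantized forward passes ...

# Convert to a real int8 model and trace it for the TVM frontend
model.eval()
quantized_model = tq.convert(model, inplace=False)
example_input = torch.rand(1, 3, 32, 32)  # assumed input shape
scripted = torch.jit.trace(quantized_model, example_input).eval()
```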

Any suggestions how this could happen?

Trying to dive further into the accuracy drop-off, I observed a strange behaviour in TVM's accuracy:

[table image: epochs trained vs. PyTorch accuracy, TVM accuracy, and %-delta]

The first column shows the number of epochs I trained the model for. The next two columns show the measured accuracy of the QAT-trained model at the PyTorch and TVM level, respectively. The last column is the %-delta between the two accuracy values. Apparently the accuracy at the TVM level decreases the longer I train my model (while the accuracy at the PyTorch level obviously increases). Another interesting thing I observed concerns freezing the quantization parameters earlier in the training process: the earlier I freeze the quantization parameters, the higher the accuracy is in TVM (freezing after 1 epoch: 77% accuracy; freezing after 8 of 20 epochs: 60%).
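By "freezing the quantization parameters" I mean disabling the observers at some point during training, roughly like this (a sketch; `freeze_epoch`, `num_epochs` and `train_one_epoch` are placeholders):

```python
import torch.quantization as tq

freeze_epoch = 8  # assumed: epoch after which scale/zero-point are no longer updated

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)  # placeholder for the usual QAT training step
    if epoch >= freeze_epoch:
        # Stop updating min/max statistics; fake quantization keeps running
        # with the scale/zero-point collected so far.
        model.apply(tq.disable_observer)
```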

I'm sorry if I missed any necessary information; it's my first post on this forum! I will happily add any requested information or even scripts to reproduce.

Do you have any idea how…

  1. The drop-off from around 90% (in PyTorch) to 60% (in TVM) can be explained?
  2. The decrease in TVM accuracy with later freezing of the quantization parameters can be explained?

Best regards, Knight3

When you convert quantized PyTorch models to TVM, you should do the quantization in exactly the same way as the PyTorch tutorial shows. If you are comparing the accuracy of a model quantized the PyTorch-tutorial way vs. the TVM-tutorial way, you are not comparing the same models. The dummy quantization procedure in the TVM tutorial is not relevant for you.

If you are correctly doing what I said above, then I don't know why the accuracy would drop after conversion to TVM. It's better to make sure post-training quantization works before moving on to QAT.
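For a quick sanity check, post-training static quantization is only a few lines (a sketch, assuming an eval-mode float model `model` and a `calib_loader` yielding representative inputs):

```python
import torch
import torch.quantization as tq

model.eval()
model.qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(model, inplace=True)

# Calibration pass to collect activation ranges
with torch.no_grad():
    for images, _ in calib_loader:
        model(images)

quantized_model = tq.convert(model, inplace=False)
```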


Thank you masahi for your reply. In the TVM tutorial they just use dummy PTQ on a pretrained MobileNet; I didn't use that. I literally just imported the QAT-trained model as a traced TorchScript module into TVM and tried to convert it to Relay. In my case I used QAT the way PyTorch describes it in their quantization tutorial. I will do PTQ tests today and give you feedback on my results.
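For clarity, the import on the TVM side looked roughly like this (a sketch following the deploy_prequantized tutorial; the input name, the shape and the `scripted`/`some_batch` names are placeholders):

```python
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

input_name = "input"                            # assumed input name
input_shapes = [(input_name, (1, 3, 32, 32))]   # assumed input shape

# `scripted` is the traced, already-quantized TorchScript module
mod, params = relay.frontend.from_pytorch(scripted, input_shapes)
# Printing `mod` here shows qnn.* ops (qnn.quantize, qnn.conv2d, ...) in the Relay IR

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
runtime.set_input(input_name, some_batch.numpy())  # `some_batch` is a test batch
runtime.run()
tvm_out = runtime.get_output(0).asnumpy()
```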

First test using PTQ, training for 10 epochs:

  • PyTorch Accuracy: 90.19%
  • TVM Accuracy: 67.35%

So the results are similar (while not being as bad as with QAT). If you have no immediate suggestion, I could put together the scripts so you can reproduce. Possibly I made an error that's independent of TVM…

hmm, I've only tested on quantized models from torchvision, but for post-training quantization the accuracy shouldn't drop like that. See https://github.com/Edgecortix-Inc/pytorch_quantization/tree/master/tvm_qnn_evaluation

Yeah, if you give me a repro script, I can have a look.

I've quickly created a repo here: https://github.com/Knight3-code/PyTorch-TVM-PTQ-Test

Hopefully it's just an error at the code level, then we could fix it easily!

I see, you are serializing and loading the quantized model. As I reported in https://github.com/pytorch/pytorch/issues/39690, there is a typing problem when a TorchScript module is serialized and loaded back.

I thought I had fixed this problem with the PR https://github.com/apache/incubator-tvm/pull/5839, but there could be some other issues that require a further workaround.

Can you try quantizing and compiling with TVM in the same process (i.e., without serialize/deserialize)?
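Concretely, I mean something like this (a sketch using the placeholder names from the snippets above):

```python
# Problematic path (triggers the typing issue from the PyTorch issue above):
#   torch.jit.save(scripted, "qat_model.pt")
#   ... later, possibly in another process ...
#   loaded = torch.jit.load("qat_model.pt")
#   mod, params = relay.frontend.from_pytorch(loaded, input_shapes)

# In-process variant: trace and hand the module to TVM directly,
# without going through torch.jit.save / torch.jit.load.
scripted = torch.jit.trace(quantized_model, example_input).eval()
mod, params = relay.frontend.from_pytorch(scripted, input_shapes)
```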


Thank you for your fast reply. Your suggestion was exactly correct: not serializing/deserializing actually works well. I got the following results.

  • PTQ (10 epochs): PyTorch accuracy = 89.5%, TVM accuracy = 90.05%
  • PTQ (20 epochs): PyTorch accuracy = 91.83%, TVM accuracy = 91.59%
  • QAT (10 epochs): PyTorch accuracy = 89.97%, TVM accuracy = 89.91%
  • QAT (20 epochs): PyTorch accuracy = 91.27%, TVM accuracy = 91.01%

So it seems that the deserialization in TVM causes problems, right? Sadly, this workflow is probably not applicable to my use case because of the serialization requirement.

Thank you for your efforts.

Best regards, Knight3

No, this is not a TVM problem. As soon as you torch.jit.save, the typing information is lost. All we can do is either wait for the PyTorch people to fix this problem, or work around it ourselves.

Below is the workaround that I thought would fix this problem for us. Can you check if your TVM install has this commit?


Your fix works for me. I didn't have this fix earlier; I just pulled and rebuilt TVM. Now I get the desired accuracy from serialized models. Thanks for your help.