[PyTorch][QNN] Cannot import TorchScript produced by FX graph mode quantization

Hi, I tried to use the torch.fx-based method to run static post-training quantization on a ResNet-18 model, then used torch.jit.trace to produce a TorchScript model. However, I was unable to import the TorchScript model into TVM using the PyTorch frontend; here is the error:

Traceback (most recent call last):
...
  File ".../test_pytorch_quant.py", line 28, in <module>
    mod, params = relay.frontend.from_pytorch(
  File ".../tvm/python/tvm/relay/frontend/pytorch.py", line 3963, in from_pytorch
    input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
  File ".../tvm/relay/frontend/qnn_torch.py", line 493, in add_input_quant_params_to_op_inputs
    input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")
RuntimeError: required keyword attribute 'value' is undefined

This is my test code:

import torch
from torchvision.models.resnet import resnet18
from torch.quantization import quantize_fx

import tvm
from tvm import relay


if __name__ == "__main__":
    model_fp = resnet18(pretrained=True).eval()
    qconfig = torch.quantization.get_default_qconfig("qnnpack")
    qconfig_dict = {"": qconfig}

    model_fp = quantize_fx.fuse_fx(model_fp)
    model_fp = quantize_fx.prepare_fx(model_fp, qconfig_dict)

    # Calibrate.
    x = torch.rand((1, 3, 224, 224))
    _ = model_fp(x)

    # Quantize.
    model_quant = quantize_fx.convert_fx(model_fp)

    # Get TorchScript model.
    model_traced = torch.jit.trace(model_quant, x).eval()

    # Import to TVM.
    mod, params = relay.frontend.from_pytorch(
        script_module=model_traced,
        input_infos=[("x", x.shape)],
        keep_quantized_weight=True
    )

    mod: tvm.IRModule = mod
    print("[RELAY]")
    print(mod.astext(show_meta_data=False))

I’ve never attempted converting an fp32 model originally coming from FX, let alone quantized ones. I’d need to spend non-trivial time if FX support is desired.

What is your experience with FX-based quantization like? If it is significantly better than the older eager-mode workflow, it may be worth supporting.

I think FX-based quantization is much easier than eager mode, since you don’t need to manually do the conv+bn+relu op fusion or insert quant/dequant stubs, as shown in my test code. You can basically take any fp32 model and quantize it without modifying it, as long as the model is symbolically traceable by FX. It also provides much more flexibility: you can specify different qconfigs for different submodules using just hierarchical identifier strings, which is very useful for mixed-precision inference (see the sketch below).
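
For illustration, here is a minimal sketch of what a per-submodule qconfig looks like. The qconfig_dict keys ("module_name", qconfig=None to skip a submodule) and the submodule name "fc" are assumptions that depend on the PyTorch version and the model, so treat this as a rough guide rather than exact API.

import torch
from torch.quantization import quantize_fx
from torchvision.models.resnet import resnet18

model_fp = resnet18(pretrained=True).eval()

# Global int8 qconfig plus per-submodule overrides, addressed by their
# hierarchical identifier strings. A qconfig of None leaves that submodule
# in fp32, which is one way to express mixed-precision inference.
qconfig_dict = {
    "": torch.quantization.get_default_qconfig("qnnpack"),
    "module_name": [
        ("fc", None),  # keep the final classifier in fp32
    ],
}

model_prepared = quantize_fx.prepare_fx(model_fp, qconfig_dict)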

Yeah, I understand that in theory, but I wonder what the actual user experience is like (robustness, accuracy, etc.). The old approach was painful, but it worked when done properly.

For example, I don’t expect to be able to quantize MaskRCNN from torchvision using FX out of the box, because FX doesn’t support control flow.
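
A toy illustration of that limitation (this module is hypothetical, not from the thread): data-dependent branching in forward() makes FX symbolic tracing fail, and prepare_fx relies on that tracing.

from torch import fx, nn

class Branchy(nn.Module):
    def forward(self, x):
        # Branching on tensor values is data-dependent control flow,
        # which FX cannot represent in its symbolic trace.
        if x.sum() > 0:
            return x * 2
        return x - 1

try:
    fx.symbolic_trace(Branchy())
except Exception as e:  # typically a torch.fx.proxy.TraceError
    print(e)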

Sorry, I cannot answer this question since I’m just doing a little experiment here to see whether this quantization-and-deployment flow can work or not. Maybe more user feedback is needed.

FYI, I started experimenting with FX-based quantization in this commit: https://github.com/masahi/torchscript-to-tvm/commit/e0a376ce3a2b2d0fdaf51759f0dbf4abecddb233

For now my interest is in collecting good models for int8 perf benchmarking, so I don’t care about accuracy or calibration. I’m seeing good-looking graphs from FX-based quantization: it works out of the box for DeepLabV3, works for YOLOv5 after some tweaks, etc. If I skip calibration, I only need to add a few lines to get quantized models (sketched below).
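
Roughly, those few lines look like this (a sketch along the lines of the commit above; the helper name and the single dummy batch in place of real calibration are my own additions, so accuracy is meaningless, but the quantized graph structure is what matters for perf benchmarking):

import torch
from torch.quantization import get_default_qconfig, quantize_fx
from torchvision.models.resnet import resnet18

def quantize_for_benchmark(model, inp):
    # Prepare with a global int8 qconfig, run one dummy batch instead of
    # real calibration, then convert to a quantized model.
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = quantize_fx.prepare_fx(model.eval(), qconfig_dict)
    prepared(inp)
    return quantize_fx.convert_fx(prepared)

qmodel = quantize_for_benchmark(resnet18(pretrained=True), torch.rand(1, 3, 224, 224))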

OK, I’ve got imports of resnet18 and deeplabv3 quantized by FX working in this branch: https://github.com/apache/tvm/compare/main...masahi:fx-quant?expand=1. It doesn't seem like too much effort is required to support FX. I’ll send a PR soon, after doing more testing.

Thanks for your efforts!

PR https://github.com/apache/tvm/pull/10091

Hi @Nullko. Did you manage to fix this issue? I am facing the same issue.

Hi @masahi, can you please share your script for torch.fx + int8 compilation with TVM?

It’s here https://github.com/apache/tvm/blob/95aac9224eb5ef30ab5bb67471c1b6ddfd6e1d6e/tests/python/frontend/pytorch/test_fx_quant.py

Still getting the same error:

RuntimeError: required keyword attribute 'value' is undefined

RuntimeError                              Traceback (most recent call last)
Input In [21], in <cell line: 1>()
----> 1 test_imagenet()

Input In [20], in test_imagenet()
     66 def test_imagenet():
     67     for model_func in [resnet50, efficientnet_b4]:
---> 68         quantize_and_build(model_func(pretrained=True).eval(), 224)

Input In [20], in quantize_and_build(model, in_size)
     22 with torch.no_grad():
     23     script_module = torch.jit.trace(qmodel, inp)
---> 24     mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)])
     25     mod = relay.transform.InferType()(mod)
     27     # Make sure that the model is quantized

File ~/.local/lib/python3.8/site-packages/tvm-0.8.0-py3.8-linux-x86_64.egg/tvm/relay/frontend/pytorch.py:3882, in from_pytorch(script_module, input_infos, custom_convert_map, default_dtype, use_parser_friendly_name, keep_quantized_weight)
   3878 if len(quantized_ops.intersection(set(op_names))) > 0:
   3879     weight_quant_params = qnn_torch.get_weight_quant_params(
   3880         script_module, packed_param_map.values()
   3881     )
-> 3882     input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
   3883     qnn_torch.add_quant_params_to_outputs(
   3884         outputs,
   3885         packed_param_map,
   (...)
   3888         keep_quantized_weight,
   3889     )
   3890     qnn_torch.add_quant_params(tvm_params, weight_quant_params)

File ~/.local/lib/python3.8/site-packages/tvm-0.8.0-py3.8-linux-x86_64.egg/tvm/relay/frontend/qnn_torch.py:493, in add_input_quant_params_to_op_inputs(graph)
    489         node.addInput(zp)
    491     if "conv" in operator or "linear" in operator:
    492         # This is required for quantizing the bias
--> 493         input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")
    495 return input_scales_for_bias

RuntimeError: required keyword attribute 'value' is undefined

The test runs on CI, so it shouldn’t result in an error. It looks like your TVM install is old (v0.8); please try the latest version (main).

Thanks @masahi. I have run it on a newer version and that issue is fixed. But the compiled int8 model is way slower than vanilla PyTorch ResNet-50, and it also looks like the tensors are being converted back to float32 during compilation.

Maybe using keep_quantized_weight=True in from_pytorch helps. See below.
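
For reference, a minimal sketch of that call, reusing the names from the traceback above (script_module, input_name, inp):

mod, params = relay.frontend.from_pytorch(
    script_module,
    [(input_name, inp.shape)],
    # Keep the pre-quantized int8 weights instead of float weights that
    # get quantized again inside the Relay graph.
    keep_quantized_weight=True,
)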

The tracer warning that I am getting comes from the step below:

script_module = torch.jit.trace(qmodel, inp)

But it looks like Relay is not able to take the quantized model. I did set keep_quantized_weight=True, but the inference time is still >2x slower than vanilla PyTorch FP32.

How are your results looking for the compiled int8 model vs. int8 vs. FP32?

It depends on the target. For x86, unless you use AVX-512, int8 is not expected to be any faster than fp32.
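
For example (a sketch; mod and params come from from_pytorch above, and the -mcpu value has to match your actual CPU), targeting an AVX-512/VNNI-capable machine looks like this:

# Cascade Lake and newer Intel Xeons have the AVX-512 VNNI int8 instructions,
# which TVM needs in order to make int8 conv/dense faster than fp32.
target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)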