Hi, I tried to use the torch.fx-based method to run static post-training quantization on a ResNet-18 model, then used torch.jit.trace to produce a TorchScript model. However, I was unable to import the TorchScript model into TVM using the PyTorch frontend; here’s the error:
Traceback (most recent call last):
...
File ".../test_pytorch_quant.py", line 28, in <module>
mod, params = relay.frontend.from_pytorch(
File ".../tvm/python/tvm/relay/frontend/pytorch.py", line 3963, in from_pytorch
input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
File ".../tvm/relay/frontend/qnn_torch.py", line 493, in add_input_quant_params_to_op_inputs
input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")
RuntimeError: required keyword attribute 'value' is undefined
This is my test code:
import torch
from torchvision.models.resnet import resnet18
from torch.quantization import quantize_fx

import tvm
from tvm import relay

if __name__ == "__main__":
    model_fp = resnet18(pretrained=True).eval()
    qconfig = torch.quantization.get_default_qconfig("qnnpack")
    qconfig_dict = {"": qconfig}
    model_fp = quantize_fx.fuse_fx(model_fp)
    model_fp = quantize_fx.prepare_fx(model_fp, qconfig_dict)
    # Calibrate.
    x = torch.rand((1, 3, 224, 224))
    _ = model_fp(x)
    # Quantize.
    model_quant = quantize_fx.convert_fx(model_fp)
    # Get TorchScript model.
    model_traced = torch.jit.trace(model_quant, x).eval()
    # Import to TVM.
    mod, params = relay.frontend.from_pytorch(
        script_module=model_traced,
        input_infos=[("x", x.shape)],
        keep_quantized_weight=True
    )
    mod: tvm.IRModule = mod
    print("[RELAY]")
    print(mod.astext(show_meta_data=False))
I’ve never attempted converting an fp32 model originally coming from FX, let alone a quantized one. I’d need to spend non-trivial time if FX support is desired.
What is your experience with FX-based quantization like? If it’s significantly better than the older eager-mode workflow, it may be worth supporting.
I think FX-based quantization is much easier than eager mode, since you don’t need to manually do the conv+bn+relu op fusion or insert quant/dequant stubs, as shown in my test code. You can basically take any fp model and quantize it without modifying it (as long as the fp model is symbolically traceable by FX). It also provides much more flexibility: you can specify different qconfigs for different submodules using just their hierarchical identifier strings, which is very useful for mixed-precision inference (see the sketch below).
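To illustrate the per-submodule overrides, here is a minimal sketch using the same qconfig_dict format as my test code above. The "module_name" key and the None-to-skip behavior are what I recall for this PyTorch version (treat them as assumptions; the exact keys may differ between releases), and "layer4" / "fc" are just ResNet-18 submodule names used as examples.

import torch
from torch.quantization import quantize_fx
from torchvision.models.resnet import resnet18

# Sketch: per-submodule qconfig overrides via hierarchical module names.
global_qconfig = torch.quantization.get_default_qconfig("qnnpack")
qconfig_dict = {
    "": global_qconfig,            # default qconfig for the whole model
    "module_name": [
        ("layer4", None),          # assumption: None skips quantization for this block
        ("fc", global_qconfig),    # explicitly quantize the classifier head
    ],
}

model_fp = resnet18(pretrained=True).eval()
prepared = quantize_fx.prepare_fx(model_fp, qconfig_dict)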
Yeah, I understand that in theory. I wonder what the actual user experience is like (robustness, accuracy, etc.). The old approach was painful, but it worked when done properly.
For example, I don’t expect to be able to quantize MaskRCNN from torchvision using FX out of the box, because FX doesn’t support control flow.
Sorry, I cannot answer this question, since I’m just doing a small experiment here to see whether this quantization-and-deployment flow can work. Maybe more user feedback is needed.
For now my interest is in collecting good models for int8 perf benchmarking, so I don’t care about accuracy or calibration. I’m seeing good-looking graphs from FX-based quantization: it works out of the box for DeepLab v3, and works for YOLOv5 after some tweaks. If I don’t do calibration, I only need to add a few lines to get quantized models (see the sketch below).
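Roughly, the “few lines” look like this. This is only a sketch: random data stands in for calibration, and I’m assuming the same quantize_fx API as the test code above and the torchvision deeplabv3_resnet50 constructor.

import torch
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.quantization import quantize_fx

model_fp = deeplabv3_resnet50(pretrained=True).eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig("qnnpack")}

prepared = quantize_fx.prepare_fx(model_fp, qconfig_dict)
_ = prepared(torch.rand((1, 3, 224, 224)))   # "calibration" with one random batch
model_quant = quantize_fx.convert_fx(prepared)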
Thanks @masahi. I have run it on the newer version and the import issues are fixed. But the compiled int8 model is way slower than vanilla PyTorch ResNet-50, and it also looks like the tensors are being converted back to float32 while compiling.
It looks like Relay is not able to take the quantized model: I did set keep_quantized_weight=True, but the inference time is still >2x slower than vanilla PyTorch FP32.
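One quick way to check whether quantization actually survived the import (a sketch, continuing from the mod returned by relay.frontend.from_pytorch in the test code above): if it did, the Relay text form should contain qnn.* operators such as qnn.conv2d and qnn.quantize; if everything shows up as plain nn.conv2d on float32 tensors, the model was effectively de-quantized on the way in.

# Assumes `mod` is the IRModule returned by relay.frontend.from_pytorch above.
relay_text = mod.astext(show_meta_data=False)
print("qnn.conv2d present:  ", "qnn.conv2d" in relay_text)
print("qnn.quantize present:", "qnn.quantize" in relay_text)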