[QUANTIZATION] How to handle PyTorch model that expects quint8 input

With help from the TVM and PyTorch communities, I was able to figure out how to quantize PyTorch models so that they accept torch.quint8 input directly, instead of the default flow of taking torch.float32 input and quantizing it inside the model. The code snippet for that is below.

import copy

import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(2, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3),
    nn.ReLU(),
)

## FX GRAPH
from torch.quantization import quantize_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

m_int = copy.deepcopy(model)
m_int.eval()

backend_int = "x86"  # running on a x86 CPU. Use "qnnpack" if running on ARM.
qconfig_dict_int = {"": torch.quantization.get_default_qconfig(backend_int)}
prepare_custom_config = PrepareCustomConfig() 
prepare_custom_config.set_input_quantized_indexes([0])  
prepare_custom_config.set_output_quantized_indexes([0])
# Prepare

example_input_int = torch.quantize_per_tensor(torch.rand(1,2,28,28), 1.0/255, 0, torch.quint8)

model_prepared_int = quantize_fx.prepare_fx(
    m_int, qconfig_dict_int,example_input_int, prepare_custom_config)

# Calibrate with representative (validation) data.
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        model_prepared_int(x)

# Quantize
model_quantized_int = quantize_fx.convert_fx(model_prepared_int)

res_int = model_quantized_int(example_input_int)

The difference these settings make can be seen below by comparing against the standard quantized model (referred to as model_quantized_d here, prepared without the PrepareCustomConfig), especially in their forward() methods.

print(model_quantized_d)    # standard quantized model: float32 in, float32 out

print(model_quantized_int)  # quint8-input model: quint8 in, quint8 out

GraphModule(
  (0): QuantizedConvReLU2d(2, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009286325424909592, zero_point=0)
  (2): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.004485216923058033, zero_point=0)
)

def forward(self, input):
    input_1 = input
    _0_input_scale_0 = getattr(self, "0_input_scale_0")
    _0_input_zero_point_0 = getattr(self, "0_input_zero_point_0")
    quantize_per_tensor = torch.quantize_per_tensor(input_1, _0_input_scale_0, _0_input_zero_point_0, torch.quint8);  input_1 = _0_input_scale_0 = _0_input_zero_point_0 = None
    _0 = getattr(self, "0")(quantize_per_tensor);  quantize_per_tensor = None
    _2 = getattr(self, "2")(_0);  _0 = None
    dequantize_2 = _2.dequantize();  _2 = None
    return dequantize_2
    
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
  (0): QuantizedConvReLU2d(2, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009379507973790169, zero_point=0)
  (2): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.004618248902261257, zero_point=0)
)

def forward(self, input):
    input_1 = input
    _0 = getattr(self, "0")(input_1);  input_1 = None
    _2 = getattr(self, "2")(_0);  _0 = None
    return _2
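
In practice the difference looks like this (a small usage sketch; model_quantized_d is assumed to be the standard quantized model from the comparison above, which is not defined in the snippet):

x_float = torch.rand(1, 2, 28, 28)
x_quint8 = torch.quantize_per_tensor(x_float, 1.0 / 255, 0, torch.quint8)

# Standard model: takes float32 and quantizes/dequantizes inside forward().
out_d = model_quantized_d(x_float)        # float32 out

# quint8-input model: the caller supplies an already-quantized tensor,
# and the output stays quantized (no dequantize at the end).
out_int = model_quantized_int(x_quint8)   # quint8 out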

However, TVM’s PyTorch frontend failed to recognize the quint8 dtype.

from tvm import relay

input_shape = [1, 2, 28, 28]
input_name = "input0"
shape_list = [(input_name, input_shape)]

with torch.no_grad():
    trace = torch.jit.trace(model_quantized_int, example_input_int)
    torch_res = model_quantized_int(example_input_int)

mod, params = relay.frontend.from_pytorch(trace, shape_list)

The error messages:

---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
/var/folders/cm/8dfdl1jx59s_qthfgrk14l040000gr/T/ipykernel_58153/1090387853.py in <module>
----> 1 mod, params = relay.frontend.from_pytorch(trace, shape_list)

~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in from_pytorch(script_module, input_infos, custom_convert_map, default_dtype, use_parser_friendly_name, keep_quantized_weight, export_renamed_c_graph_path)
   4951     params = script_module.state_dict() if is_module else {}
   4952     outputs = _get_relay_input_vars(
-> 4953         graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
   4954     )
   4955 

~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in _get_relay_input_vars(graph, input_infos, prelude, is_module, default_dtype)
   4708     input_types = [
   4709         (name, get_relay_ty(info[0], info[1], gi.type()))
-> 4710         for (name, info), gi in zip(new_input_infos, graph_inputs)
   4711     ]
   4712 

~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in <listcomp>(.0)
   4708     input_types = [
   4709         (name, get_relay_ty(info[0], info[1], gi.type()))
-> 4710         for (name, info), gi in zip(new_input_infos, graph_inputs)
   4711     ]
   4712 

~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in get_relay_ty(ishape, itype, pt_type)
   4663                 pt_dtype = itype
   4664             dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
-> 4665             return TensorType(ishape, dtype)
   4666         elif pt_type.kind() == "TupleType":
   4667             if not isinstance(ishape, tuple):

~/workspace/OSS_TVM/tvm_0p11/python/tvm/ir/tensor_type.py in __init__(self, shape, dtype)
     39 
     40     def __init__(self, shape, dtype="float32"):
---> 41         self.__init_handle_by_constructor__(_ffi_api.TensorType, shape, dtype)
     42 
     43     @property

~/workspace/OSS_TVM/tvm_0p11/python/tvm/_ffi/_ctypes/object.py in __init_handle_by_constructor__(self, fconstructor, *args)
    143         # pylint: disable=not-callable
    144         self.handle = None
--> 145         handle = __init_by_constructor__(fconstructor, args)
    146         if not isinstance(handle, ObjectHandle):
    147             handle = ObjectHandle(handle)

~/workspace/OSS_TVM/tvm_0p11/python/tvm/_ffi/_ctypes/packed_func.py in __init_handle_by_constructor__(fconstructor, args)
    258         != 0
    259     ):
--> 260         raise get_last_ffi_error()
    261     _ = temp_args
    262     _ = args

TVMError: Traceback (most recent call last):
  File "/Users/***/workspace/OSS_TVM/tvm_0p11/include/tvm/runtime/packed_func.h", line 777
TVMError: In function ir.TensorType(0: Array<PrimExpr>, 1: DataType) -> relay.TensorType: error while converting argument 1: [01:16:56] /Users/***/workspace/OSS_TVM/tvm_0p11/include/tvm/runtime/data_type.h:383: unknown type quint8

Looking at frontend/pytorch.py: https://github.com/apache/tvm/blob/da7b48f9487c9ee8eb8c6a6d7f80b59969f842c8/python/tvm/relay/frontend/pytorch.py#L4543

Lines 4521-4523 are where the issue appears:

outputs = _get_relay_input_vars(
    graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
)
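
The failure itself can be reproduced in isolation: TVM's runtime DataType has no notion of PyTorch's quantized dtypes, so constructing a TensorType with "quint8" raises the same error as in the traceback above (minimal sketch):

from tvm import relay

relay.TensorType([1, 2, 28, 28], "uint8")   # fine: a plain integer storage dtype
relay.TensorType([1, 2, 28, 28], "quint8")  # TVMError: unknown type quint8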

Lines 4536-4543 are where handling of quantized models starts:

# For quantized models
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
    weight_quant_params = qnn_torch.get_weight_quant_params(
        script_module, packed_param_map.values()
    )
    qnn_torch.inline_input_quant_params_for_fx(graph, tensors)
    input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)

If some quantization-related processing happened before _get_relay_input_vars() is called, would that solve the problem? A rough sketch of the kind of mapping I have in mind is below. Thanks.
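
For illustration only, here is a purely hypothetical sketch (none of these names exist in TVM) of what that preprocessing could look like: mapping PyTorch's quantized scalar types to their underlying storage dtypes before the Relay input vars are built, so that TensorType never sees "quint8". The existing qnn_torch passes would still be responsible for attaching the scale and zero point.

# Hypothetical sketch only -- not actual TVM code.
# Map PyTorch quantized dtypes to the plain integer dtypes that store them.
QUANT_STORAGE_DTYPE = {
    "quint8": "uint8",
    "qint8": "int8",
    "qint32": "int32",
}

def to_storage_dtype(dtype):
    """Return the storage dtype backing a PyTorch quantized dtype."""
    return QUANT_STORAGE_DTYPE.get(dtype, dtype)

# e.g. the quint8 graph input above would become a Relay var of dtype "uint8"
assert to_storage_dtype("quint8") == "uint8"
assert to_storage_dtype("float32") == "float32"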
