With help from the TVM and PyTorch communities, I was able to figure out how to quantize PyTorch models such that they’ll accept torch.quint8 input, instead of the default torch.float32 input followed by an internal quantize op. You can find the code snippet for that below.
import copy
import torch
from torch import nn

# Toy model: two Conv2d + ReLU blocks.
model = nn.Sequential(
    nn.Conv2d(2, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3),
    nn.ReLU(),
)
## FX GRAPH
from torch.quantization import quantize_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

m_int = copy.deepcopy(model)
m_int.eval()
backend_int = "x86"  # running on an x86 CPU; use "qnnpack" if running on ARM
qconfig_mapping_int = QConfigMapping().set_global(
    torch.quantization.get_default_qconfig(backend_int))

# Mark input 0 and output 0 as already quantized, so convert_fx will not
# insert quantize/dequantize stubs at the model boundary.
prepare_custom_config = PrepareCustomConfig()
prepare_custom_config.set_input_quantized_indexes([0])
prepare_custom_config.set_output_quantized_indexes([0])

# Prepare: insert observers (example_inputs is expected to be a tuple).
example_input_int = torch.quantize_per_tensor(
    torch.rand(1, 2, 28, 28), 1.0 / 255, 0, torch.quint8)
model_prepared_int = quantize_fx.prepare_fx(
    m_int, qconfig_mapping_int, (example_input_int,),
    prepare_custom_config=prepare_custom_config)
# Calibrate - use representative (validation) data.
with torch.inference_mode():
    for _ in range(10):
        # x = torch.randint(255, [1, 2, 28, 28])
        x = torch.rand(1, 2, 28, 28)
        model_prepared_int(x)

# Convert the calibrated model to a quantized one.
model_quantized_int = quantize_fx.convert_fx(model_prepared_int)
res_int = model_quantized_int(example_input_int)
The difference these quantization settings make can be seen below when compared against the standard quantized model, especially in their forward() methods.
# model_quantized_d is the default quantized model (float32 in, float32 out)
# built earlier without PrepareCustomConfig; model_quantized_int is the one above.
print(model_quantized_d)
print(model_quantized_int)
GraphModule(
  (0): QuantizedConvReLU2d(2, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009286325424909592, zero_point=0)
  (2): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.004485216923058033, zero_point=0)
)

def forward(self, input):
    input_1 = input
    _0_input_scale_0 = getattr(self, "0_input_scale_0")
    _0_input_zero_point_0 = getattr(self, "0_input_zero_point_0")
    quantize_per_tensor = torch.quantize_per_tensor(input_1, _0_input_scale_0, _0_input_zero_point_0, torch.quint8); input_1 = _0_input_scale_0 = _0_input_zero_point_0 = None
    _0 = getattr(self, "0")(quantize_per_tensor); quantize_per_tensor = None
    _2 = getattr(self, "2")(_0); _0 = None
    dequantize_2 = _2.dequantize(); _2 = None
    return dequantize_2

# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
  (0): QuantizedConvReLU2d(2, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009379507973790169, zero_point=0)
  (2): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.004618248902261257, zero_point=0)
)

def forward(self, input):
    input_1 = input
    _0 = getattr(self, "0")(input_1); input_1 = None
    _2 = getattr(self, "2")(_0); _0 = None
    return _2
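To make the difference concrete, here is a quick sanity check (a sketch, assuming both quantized models are still in scope): the default model takes and returns float32 tensors, while the int-input model consumes a quint8 tensor directly and, since output index 0 was also marked quantized, returns one.

# Sanity check (sketch): compare the two models' input/output dtypes.
res_d = model_quantized_d(torch.rand(1, 2, 28, 28))  # float32 in -> float32 out
print(res_d.dtype)                  # torch.float32
res_int = model_quantized_int(example_input_int)     # quint8 in -> quint8 out
print(res_int.dtype)                # torch.quint8
print(res_int.dequantize().dtype)   # torch.float32 after an explicit dequantize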
However, TVM’s PyTorch frontend failed to recognize the quint8 dtype.
from tvm import relay

input_shape = [1, 2, 28, 28]
input_name = "input0"
shape_list = [(input_name, input_shape)]

# Trace the int-input model with the quint8 example input.
with torch.no_grad():
    trace = torch.jit.trace(model_quantized_int, example_input_int)
    torch_res = model_quantized_int(example_input_int)

mod, params = relay.frontend.from_pytorch(trace, shape_list)
The error messages:
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
/var/folders/cm/8dfdl1jx59s_qthfgrk14l040000gr/T/ipykernel_58153/1090387853.py in <module>
----> 1 mod, params = relay.frontend.from_pytorch(trace, shape_list)
~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in from_pytorch(script_module, input_infos, custom_convert_map, default_dtype, use_parser_friendly_name, keep_quantized_weight, export_renamed_c_graph_path)
4951 params = script_module.state_dict() if is_module else {}
4952 outputs = _get_relay_input_vars(
-> 4953 graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
4954 )
4955
~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in _get_relay_input_vars(graph, input_infos, prelude, is_module, default_dtype)
4708 input_types = [
4709 (name, get_relay_ty(info[0], info[1], gi.type()))
-> 4710 for (name, info), gi in zip(new_input_infos, graph_inputs)
4711 ]
4712
~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in <listcomp>(.0)
4708 input_types = [
4709 (name, get_relay_ty(info[0], info[1], gi.type()))
-> 4710 for (name, info), gi in zip(new_input_infos, graph_inputs)
4711 ]
4712
~/workspace/OSS_TVM/tvm_0p11/python/tvm/relay/frontend/pytorch.py in get_relay_ty(ishape, itype, pt_type)
4663 pt_dtype = itype
4664 dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
-> 4665 return TensorType(ishape, dtype)
4666 elif pt_type.kind() == "TupleType":
4667 if not isinstance(ishape, tuple):
~/workspace/OSS_TVM/tvm_0p11/python/tvm/ir/tensor_type.py in __init__(self, shape, dtype)
39
40 def __init__(self, shape, dtype="float32"):
---> 41 self.__init_handle_by_constructor__(_ffi_api.TensorType, shape, dtype)
42
43 @property
~/workspace/OSS_TVM/tvm_0p11/python/tvm/_ffi/_ctypes/object.py in __init_handle_by_constructor__(self, fconstructor, *args)
143 # pylint: disable=not-callable
144 self.handle = None
--> 145 handle = __init_by_constructor__(fconstructor, args)
146 if not isinstance(handle, ObjectHandle):
147 handle = ObjectHandle(handle)
~/workspace/OSS_TVM/tvm_0p11/python/tvm/_ffi/_ctypes/packed_func.py in __init_handle_by_constructor__(fconstructor, args)
258 != 0
259 ):
--> 260 raise get_last_ffi_error()
261 _ = temp_args
262 _ = args
TVMError: Traceback (most recent call last):
File "/Users/***/workspace/OSS_TVM/tvm_0p11/include/tvm/runtime/packed_func.h", line 777
TVMError: In function ir.TensorType(0: Array<PrimExpr>, 1: DataType) -> relay.TensorType: error while converting argument 1: [01:16:56] /Users/***/workspace/OSS_TVM/tvm_0p11/include/tvm/runtime/data_type.h:383: unknown type quint8
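One direction that might be worth trying (an untested sketch, not a confirmed fix): relay.frontend.from_pytorch lets each input entry carry an explicit dtype, and TVM’s QNN dialect represents quantized tensors as plain uint8/int8 with separate scale and zero_point rather than a dedicated quint8 type. Declaring the input as "uint8" sidesteps the failing dtype conversion for the input var, though the traced graph may still disagree further in.

# Untested workaround sketch: give from_pytorch an explicit "uint8" input
# dtype instead of letting it derive quint8 from the traced graph input.
shape_list_u8 = [(input_name, (input_shape, "uint8"))]
mod, params = relay.frontend.from_pytorch(trace, shape_list_u8)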