Quantization in TVM
(This RFC corresponds to PR #7474.)
The goal of this work is to create a flexible and extensible framework for quantizing and calibrating models. Specifically, I want to:
- Allow arbitrary patterns to be rewritten to a corresponding quantized pattern
- Support different, data-aware calibration methods, and allow new ones to be implemented easily
- Easily accommodate quantization to new datatypes in the future
I have broken the workflow down into three steps: quantization, calibration, and requantization.
In quantization, I identify patterns in the original model that we want to quantize, and replace them with a quantized version of that pattern. I set the scales and zero points in the qnn ops to relay variables, whose values will be chosen during calibration.
In calibration, I provide a callback through which users can set the scale and zero point variables to concrete values, and run intermediate parts of the graph with real inputs to support data-aware calibration.
In requantization, I remove extraneous qnn.quantize and qnn.dequantize ops and replace them with qnn.requantize. I don’t insert any qnn.requantize ops during quantization or calibration because qnn.requantize requires scales and zero points to be constant scalars, not expressions, and postponing their insertion keeps quantization modular. More on this in the Requantization section.
Quantization
In quantization, I use existing qnn ops to construct a quantized version of the graph.
There are two main classes involved in quantization: QuantizerPattern, a subclass of DFPatternCallback, and Quantizer. (DFPatternCallback finds specific patterns in a relay function, and transforms them using the pattern matcher).
The QuantizerPattern contains the pattern that we want to quantize, and also implements the callback method from the DFPatternCallback class to rewrite that pattern. For example, the Conv2DPattern class rewrites
E0
fn (data, weight) {
  %0 = nn.conv2d(data, weight)
}
as
E1
fn (data, weight, scale_var_0, zp_var_0, scale_var_1, zp_var_1) {
  %0 = qnn.quantize(data, scale_var_0, zp_var_0)
  %1 = qnn.quantize(weight, scale_var_1, zp_var_1)
  %2 = qnn.conv2d(%0, %1, zp_var_0, zp_var_1, scale_var_0, scale_var_1)
  %3 = qnn.dequantize(%2, scale_var_0 * scale_var_1, relay.const(0, dtype='int32'))
}
Here is a shorter version of Conv2DPattern, the QuantizerPattern that does this transformation:
E2
class Conv2DPattern(QuantizerPattern):
    def __init__(self, calibration_callback):
        self.calibration_callback = calibration_callback
        super().__init__(calibration_callback)
        self.input = wildcard()
        self.conv_weight = wildcard()
        self.inputs = [self.input, self.conv_weight]
        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
        self.pattern = self.conv2d
        self.attrs = None
        self.weight_channel_axis = None
        self.data_channel_axis = None
        self.channels = None

    def callback(self, pre, post, node_map):
        self.args = [node_map[i][0] for i in self.inputs]
        conv2d = node_map[self.conv2d][0]
        self.out_dtype = conv2d.checked_type.dtype
        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
        self.create_scale_zps("conv2d_data", "conv2d_weight")
        self.quantize_args()

        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
        # Conv zp is zero since QNN deals with input zps for us
        conv_zp = relay.const(0, dtype="int32")
        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]

        if self.padding is not None:
            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
            if self.kernel_layout == "OIHW":
                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
            elif self.kernel_layout == "HWIO":
                pad_width = (
                    (top, bottom),
                    (left, right),
                    (0, 0),
                    (0, 0),
                )
            pad_val = 0
            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)

        # Construct quantized qnn.conv2d and dequantize
        qnn_call = self.create_conv(args)
        dequantized_call = relay.qnn.op.dequantize(
            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
        )
        return dequantized_call

    def quantize_args(self):
        """Helper to quantize the arguments to the qnn.conv2d."""
        quantized_data = relay.qnn.op.quantize(
            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
        )
        quantized_weight = relay.qnn.op.quantize(
            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
        )
        self.quantized_args = [quantized_data, quantized_weight]

    def create_conv(self, args):
        """Creates the qnn.conv2d.

        Parameters
        ----------
        args : List[relay.Expr]
            Quantized arguments for the qnn.conv2d.

        Returns
        -------
        q_conv2d : relay.Expr
            Quantized version of the pattern.
        """
        return relay.qnn.op.conv2d(*args, **self.attrs)

    def get_kernel_size(self, kernel_shape, kernel_layout):
        """Body omitted for brevity, gets the kernel size."""
        pass

    def get_attrs(self, attrs, kernel_shape):
        """Body omitted for brevity, constructs attrs for qnn.conv2d."""
        pass
There is a QuantizerPattern for every pattern we want to quantize. The patterns we currently support are Conv2DPattern, Conv2DBiasAddPattern, DensePattern, AddPattern, and MultiplyPattern, but it is easy to add your own if you wish to support a different pattern.
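For example, here is a minimal sketch of what a new pattern could look like. The class name Conv2DReluPattern and the exact way it hooks into Conv2DPattern are assumptions for illustration, not part of the PR; the idea is simply to extend an existing pattern so that a trailing nn.relu is matched and rewritten together with the convolution:
import numpy as np
from tvm import relay
from tvm.relay.dataflow_pattern import is_op


# Hypothetical pattern: matches nn.conv2d followed by nn.relu and quantizes the pair together.
class Conv2DReluPattern(Conv2DPattern):
    def __init__(self, calibration_callback=None):
        super().__init__(calibration_callback)
        # Extend the parent's conv2d pattern with a trailing nn.relu and make that the root
        self.relu = is_op("nn.relu")(self.conv2d)
        self.pattern = self.relu

    def create_conv(self, args):
        # Reuse the parent's qnn.conv2d construction, then apply relu to the int32 result.
        # The conv output zero point is 0, so relu can stay in the integer domain.
        qnn_conv = super().create_conv(args)
        return relay.op.nn.relu(qnn_conv)
This pattern could then be passed to the Quantizer alongside the built-in ones, e.g. Quantizer(func, params, [Conv2DReluPattern(cc), DensePattern(cc)]).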
The Quantizer takes in the function to quantize, the parameters of the function, and a list of QuantizerPatterns. Let’s say we only want to quantize conv2d ops and dense ops. Then, we could create a Quantizer like this:
E3
quantizer = Quantizer(func, params, [Conv2DPattern(), DensePattern()])
Internally, the Quantizer first partitions the graph into functions containing each pattern, then rewrites the patterns into their quantized versions. It also constructs two functions that return tuples containing many intermediate subgraph outputs, and stores indices mapping specific scale and zero point variables to these outputs, so they can be run and used in data-aware calibration.
For example, to pick values for scale_var_0, zp_var_0, scale_var_1 and zp_var_1 in E1, we might want to look at the values of data, weight and %0 (the output of the nn.conv2d) in E0, as well as the values of %0 (the quantized data), %1 (the quantized weight) and %3 (the result of qnn.conv2d converted back to float32) in E1. Relay doesn’t give us a good way to access intermediate values in functions, so I put all these values into tuples and return the tuple as the output of the function. For the original function in E0, we create a function whose output is (data, weight, %0), and for the quantized function in E1, we create a function whose output is (%0, %1, %3). For longer functions, the tuple would be much longer. For each pattern matched in the graph, we also store indices into the tuple so that we can extract the useful values during calibration.
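For instance, the tuple-output function built from E0 would look roughly like this (a hand-written sketch, not the exact output of the pass):
fn (data, weight) {
  %0 = nn.conv2d(data, weight)
  (data, weight, %0)
}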
These functions never have to be built or indexed into by users. Utility functions in the calibrator do this automatically (more on this in the next section).
Calibration
Calibration involves four classes: QuantizerPattern, CalibrationInfo, CalibrationCallback, and Calibrator.
Each QuantizerPattern has a method, calibrate_pattern, which is used during calibration to pick scale and zero point values.
calibrate_pattern returns a map from the names of the scale and zero point variables to the values we are setting them to.
It takes a CalibrationInfo object as an argument. CalibrationInfo contains the names of the scale and zero point variables for every qnn.quantize in the pattern, in pairs: [(scale_var_1, zp_var_1), (scale_var_2, zp_var_2)]. CalibrationInfo also exposes the intermediate values in the graph through the methods get_unquantized_layer_inputs, get_unquantized_layer_outputs, get_quantized_layer_inputs, and get_quantized_layer_outputs. Each of these methods takes an input to the original function, runs the quantized or unquantized function, and returns the values corresponding to AST nodes in the pattern.
For example, for the CalibrationInfo object corresponding to the pattern in E0 and E1, get_unquantized_layer_inputs returns values corresponding to [data, weight], and get_unquantized_layer_outputs returns the value corresponding to %0 in E0. get_quantized_layer_inputs returns values corresponding to [%0, %1] in E1, and get_quantized_layer_outputs returns the value corresponding to %3 in E1.
The CalibrationInfo object also optionally contains a DatasetManager. The DatasetManager is a simple wrapper class for exposing datasets from other ML frameworks to the Calibrator in a unified way. For example, there is a TFDatasetManager, which wraps a TensorFlow dataset:
E4
class TFDatasetManager(DatasetManager):
    """DatasetManager wrapping a tensorflow dataset."""

    def __init__(self, tf_dataset, batch_size, total_batches):
        self.idx = 0
        self.total_batches = total_batches
        # Stored under a different name so the attribute doesn't shadow the batch_size method
        self._batch_size = batch_size
        self.tf_dataset = tf_dataset
        self.tf_iter = iter(self.tf_dataset)

    def get_next_batch(self):
        if self.is_empty():
            raise IndexError
        self.idx += 1
        data, label = next(self.tf_iter)
        return [data.numpy()], label.numpy()

    def num_batches(self):
        return self.total_batches

    def batch_size(self):
        return self._batch_size

    def is_empty(self):
        return self.idx >= self.total_batches

    def reset(self):
        self.tf_iter = iter(self.tf_dataset)
        self.idx = 0
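For example, a DatasetManager for MNIST (used as mnist_train_manager in E11 below) might be constructed roughly like this; the preprocessing choices here are illustrative, not part of the RFC:
import tensorflow as tf

# Build a batched MNIST dataset and wrap it so the calibrator can draw batches from it
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
train_images = (train_images / 255.0).astype("float32").reshape(-1, 28, 28, 1)

batch_size = 5
num_batches = 100
mnist_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .batch(batch_size)
    .take(num_batches)
)
mnist_train_manager = TFDatasetManager(mnist_dataset, batch_size, num_batches)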
Inputs from the DatasetManager can be passed to get_quantized_layer_inputs, get_quantized_layer_outputs, get_unquantized_layer_inputs, and get_unquantized_layer_outputs.
Let’s look at writing a data-aware calibrate_pattern for MyConv2DPattern. We’ll use the DatasetManager to get inputs from the original function, and pass them to get_unquantized_layer_inputs to get the data and weight for the Conv2D op.
E5
import numpy as np


class MyConv2DPattern(Conv2DPattern):
    def calibrate_pattern(self, calibration_info):
        scale_zp_values = {}

        # Get an input to the original graph (labels are not needed here)
        inputs, _ = calibration_info.dataset_manager.get_next_batch()

        # Run the original function with the inputs and get values for data and weight in this pattern
        data_value, weight_value = calibration_info.get_unquantized_layer_inputs(inputs)

        # calibration_info.input_scale_zps = [[data_scale, data_zp], [weight_scale, weight_zp]]
        data_scale_name = calibration_info.input_scale_zps[0][0].name_hint
        data_scale = np.max(data_value) / 128
        scale_zp_values[data_scale_name] = data_scale

        # ...
        # Set all the other scales and zero points
        # ...

        # scale_zp_values would look something like
        # {'data_scale': 0.02, 'data_zp': 0, 'weight_scale': 0.05, 'weight_zp': 0.1}
        return scale_zp_values
Note: In E5 (and E8) I only use one input from the DatasetManager for the sake of simplicity. However, most data-aware algorithms will use many different inputs to the graph to calculate intermediate values, which are then used to calculate scales and zero points.
Being able to write pattern-specific calibrate_pattern methods gives us more flexibility in constructing scales and zero points. For example, to create per-channel scales, we need to know the number of channels a Conv2D op has and the number of units a Dense op has.
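As a rough illustration of that flexibility, a pattern-specific calibrate_pattern could compute one weight scale per output channel. This is a sketch, not part of the PR: the class name is hypothetical, it assumes an OIHW kernel layout, and it assumes calibrate_pattern may return per-channel arrays for scales:
import numpy as np


class PerChannelConv2DPattern(Conv2DPattern):
    def calibrate_pattern(self, calibration_info):
        scale_zp_values = {}
        inputs, _ = calibration_info.dataset_manager.get_next_batch()
        data_value, weight_value = calibration_info.get_unquantized_layer_inputs(inputs)

        # Per-tensor scale for the data
        data_scale_name = calibration_info.input_scale_zps[0][0].name_hint
        data_zp_name = calibration_info.input_scale_zps[0][1].name_hint
        scale_zp_values[data_scale_name] = np.max(np.abs(data_value)) / 128
        scale_zp_values[data_zp_name] = 0

        # Per-channel scales for the weight: one scale per output channel.
        # Assumes an OIHW kernel layout, so the output-channel axis is 0.
        weight_scale_name = calibration_info.input_scale_zps[1][0].name_hint
        weight_zp_name = calibration_info.input_scale_zps[1][1].name_hint
        scale_zp_values[weight_scale_name] = np.max(np.abs(weight_value), axis=(1, 2, 3)) / 128
        scale_zp_values[weight_zp_name] = 0
        return scale_zp_values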
In E5, however, we’re not actually using any pattern-specific information. We’ve written calibrate_pattern assuming that there are two values being quantized, data and weight, and two corresponding qnn.quantize ops. This is true for the Conv2DPattern (see E1), the DensePattern, and any other binop we want to quantize. If we want to implement the same method on the DensePattern, we would have to copy Conv2DPattern’s calibrate_pattern into DensePattern.
To reduce code duplication, we define a class called CalibrationCallback, which also has a method called calibrate_pattern. Each QuantizerPattern optionally takes a CalibrationCallback as an argument, and its calibrate_pattern calls the calibrate_pattern of the CalibrationCallback:
E6
class QuantizerPattern(DFPatternCallback):
    def __init__(self, calibration_callback):
        super().__init__()
        self.calibration_callback = calibration_callback

    def calibrate_pattern(self, calibration_info):
        return self.calibration_callback.calibrate_pattern(calibration_info)
So, if we don’t override the QuantizerPattern’s calibrate_pattern method, it will call the calibrate_pattern method of the CalibrationCallback that is passed in. Let’s take a look at what using CalibrationCallbacks looks like:
E7
cc = MyCalibrationCallback()
conv2d_pattern = Conv2DPattern(cc)
dense_pattern = DensePattern(cc)
quantizer = Quantizer(func, params, [conv2d_pattern, dense_pattern])
Now let’s implement MyCalibrationCallback. This time, we’ll write the calibrate_pattern method to be generic so that it supports any number of qnn.quantize ops in the pattern, and can be passed to any QuantizerPattern:
E8
import numpy as np


class MyCalibrationCallback(CalibrationCallback):
    def __init__(self, dataset_manager):
        self.dataset_manager = dataset_manager

    def calibrate_pattern(self, calibration_info):
        scale_zp_values = {}
        inputs, _ = calibration_info.dataset_manager.get_next_batch()

        # unquantized_values = [unquantized_value_1, unquantized_value_2, ..., unquantized_value_n],
        # one value for each qnn.quantize in the pattern
        unquantized_values = calibration_info.get_unquantized_layer_inputs(inputs)

        # calibration_info.input_scale_zps = [[scale_var_1, zp_var_1], [scale_var_2, zp_var_2], ..., [scale_var_n, zp_var_n]]
        for i in range(len(calibration_info.input_scale_zps)):
            scale_name = calibration_info.input_scale_zps[i][0].name_hint
            zp_name = calibration_info.input_scale_zps[i][1].name_hint

            # Calculate simple scale and zero point values
            scale_zp_values[scale_name] = np.max(unquantized_values[i]) / 128
            scale_zp_values[zp_name] = np.mean(unquantized_values[i]) / 128

        return scale_zp_values
The Calibrator class manages calibration at a high level, and calls calibrate_pattern. It maintains a list of all the scale and zero point values returned from calibrate_pattern, and also updates the CalibrationInfo that is passed to calibrate_pattern. It takes a Quantizer as an argument, since it needs to access information from the Quantizer. It also optionally takes a DatasetManager, which is passed to calibrate_pattern through the CalibrationInfo object.
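As a rough mental model (a sketch with hypothetical helper and attribute names, not the actual implementation), the calibrator's calibrate method loops over the matched patterns, asks each one for concrete scale and zero point values, and finally binds those values into the quantized function:
# Sketch only: build_calibration_info, bind_scale_zps, and the quantizer attributes used
# here are hypothetical stand-ins for whatever the real QuantizationCalibrator does.
def calibrate(self):
    scale_zp_map = {}
    for pattern in self.quantizer.patterns:
        # CalibrationInfo carries this pattern's scale/zp variable names, the accessors
        # for intermediate values, and the optional DatasetManager
        info = self.build_calibration_info(pattern, scale_zp_map)
        # The pattern (or its CalibrationCallback) picks concrete values
        scale_zp_map.update(pattern.calibrate_pattern(info))
    # Substitute the chosen constants for the scale and zero point variables
    return self.bind_scale_zps(self.quantizer.quantized_func, scale_zp_map)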
So, to calibrate a function completely, all you need to do is construct a QuantizationCalibrator and call its calibrate method:
E9
cc = MyCalibrationCallback()
conv2d_pattern = Conv2DPattern(cc)
dense_pattern = DensePattern(cc)
quantizer = Quantizer(func, params, [conv2d_pattern, dense_pattern])
calibrator = QuantizationCalibrator(quantizer)
calibrated_func = calibrator.calibrate()
Requantization
qnn.requantize takes an int8 value and its scale and zero point, and transforms it into another int8 value with a different scale and zero point, without going back to float32. To get a fast quantized workload, we want to stay in int8 for as long as possible. In the quantization step, we don’t use any qnn.requantize ops. We only use qnn.quantize and qnn.dequantize in that step for three reasons:
- qnn.requantize requires scales and zero points to be constants, not relay expressions. We can’t allow the scales and zero points to be expressions without sacrificing performance.
- If we were to introduce qnn.requantize directly during the quantization step, we would not be able to quantize each pattern individually, because qnn.requantize requires scale and zero point values from the next pattern.
- For quantization methods like KL-divergence, it is useful to have access to the output value of the quantized layer so we can compare it directly to the original output value. For example, we want to take %3 from E1, compare it to %0 from E0, and adjust our scales and zero points so that the values are as close as possible. If %3 in E1 were a qnn.requantize instead of a qnn.dequantize, we could only compare the output of the qnn.requantize, which is quantized, so the comparison would not be useful.
Here’s how we requantize:
E10
cc = MyCalibrationCallback()
conv2d_pattern = Conv2DPattern(cc)
dense_pattern = DensePattern(cc)
quantizer = Quantizer(func, params, [conv2d_pattern, dense_pattern])
calibrator = QuantizationCalibrator(quantizer)
calibrated_func = calibrator.calibrate()
requantized_func = Requantizer().requantize(calibrated_func)
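To illustrate the effect with a hand-written sketch (not actual pass output): after calibration, the boundary between two quantized patterns is a qnn.dequantize immediately followed by a qnn.quantize, both with constant scales and zero points. The Requantizer collapses each such pair into a single qnn.requantize so values stay in int8 between patterns:
Before requantizing (between two patterns):
  %0 = qnn.dequantize(%conv_out, 0.003f, 0)
  %1 = qnn.quantize(%0, 0.02f, 0, out_dtype="int8")

After requantizing:
  %0 = qnn.requantize(%conv_out, 0.003f, 0, 0.02f, 0, out_dtype="int8")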
End use
Since this framework is designed to be flexible and modular, there are a lot of different parts that an end user probably does not want to deal with. We provide a Relay function transformation pass that wraps quantization, calibration and requantization together. The user only has to specify the QuantizerPatterns they want to use.
However, advanced users can call the workflow directly, or combine different parts of the workflow to create new relay function passes.
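Usage of that wrapper might look something like the following; the pass name and signature here are hypothetical placeholders, since the exact entry point is defined in the PR:
# Hypothetical convenience pass wrapping quantization, calibration, and requantization;
# the user only supplies the QuantizerPatterns (and, for data-aware methods, a DatasetManager)
quantized_mod = QuantizePass(
    [Conv2DPattern(cc), DensePattern(cc)],
    params,
    dataset_manager=mnist_train_manager,
)(mod)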
Output on a simple MNIST graph
Let’s look at calibrating a small MNIST graph:
E11
cc = AverageMaxCalibrationCallback()
quantizer = Quantizer(
    mnist_func,
    params,
    [Conv2DBiasAddPattern(cc), Conv2DPattern(cc), DensePattern(cc), AddPattern(cc), MultiplyPattern(cc)],
    skip_first=False,
)
calibrator = QuantizationCalibrator(quantizer, target='llvm', ctx=tvm.cpu(), dataset_manager=mnist_train_manager)
calibrated_func = calibrator.calibrate()
calibrated_mod = tvm.ir.IRModule.from_expr(calibrated_func)
requantized_func = Requantizer().requantize(calibrated_func)
E12 mnist_func
fn (%flatten_input: Tensor[(5, 28, 28, 1), float32], %dense_1/kernel:0: Tensor[(128, 10), float32], %dense_1/bias:0: Tensor[(10), float32], %dense/kernel:0: Tensor[(784, 128), float32], %dense/bias:0: Tensor[(128), float32]) {
  %0 = nn.batch_flatten(%flatten_input);
  %1 = transpose(%dense/kernel:0, axes=[1, 0]);
  %2 = nn.dense(%0, %1, units=None);
  %3 = add(%2, %dense/bias:0);
  %4 = nn.relu(%3);
  %5 = transpose(%dense_1/kernel:0, axes=[1, 0]);
  %6 = nn.dense(%4, %5, units=None);
  %7 = add(%6, %dense_1/bias:0);
  nn.softmax(%7)
}
E13 MNIST model after quantization, calibration and requantization:
fn (%flatten_input: Tensor[(5, 28, 28, 1), float32]) -> Tensor[(5, 10), float32] {
  %0 = nn.batch_flatten(%flatten_input) /* ty=Tensor[(5, 784), float32] */;
  %1 = qnn.quantize(%0, 0.00390625f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int8") /* ty=Tensor[(5, 784), int8] */;
  %2 = qnn.quantize(meta[relay.Constant][0] /* ty=Tensor[(128, 784), float32] */, 0.00453253f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(128, 784), int8] */;
  %3 = qnn.dense(%1, %2, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.00390625f /* ty=float32 */, 0.00453253f /* ty=float32 */, units=128, out_dtype="int32") /* ty=Tensor[(5, 128), int32] */;
  %4 = qnn.quantize(meta[relay.Constant][1] /* ty=Tensor[(128), float32] */, 0.00390625f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(128), int32] */;
  %5 = nn.bias_add(%3, %4) /* ty=Tensor[(5, 128), int32] */;
  %6 = qnn.requantize(%5, 1.77052e-05f /* ty=float32 */, 0 /* ty=int32 */, 0.0267685f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int8");
  %7 = nn.relu(%6);
  %8 = qnn.quantize(meta[relay.Constant][2] /* ty=Tensor[(10, 128), float32] */, 0.00579835f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(10, 128), int8] */;
  %9 = qnn.dense(%7, %8, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0267685f /* ty=float32 */, 0.00579835f /* ty=float32 */, units=10, out_dtype="int32");
  %10 = qnn.quantize(meta[relay.Constant][3] /* ty=Tensor[(10), float32] */, 0.0267685f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(10), int32] */;
  %11 = nn.bias_add(%9, %10);
  %12 = qnn.dequantize(%11, 0.000155213f /* ty=float32 */, 0 /* ty=int32 */, axis=1);
  nn.softmax(%12)
}
(Note that I didn’t skip the first or last pattern when quantizing this model, but you can if you want to).