Introduction and motivation
Mathematically, fixed-point multiplication (FPM) can be described as:
$$\mathrm{fpm}(x, m, s) = \mathrm{round}\left(x \cdot m \cdot 2^{s-31}\right)$$
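In this expression:
- x is the quantized value to multiply, while m and s are an integer multiplier and a shift. The multiplier m is read as a Q31 number (its real value is $m \cdot 2^{-31}$), and the shift rescales the result by $2^{s}$.
- The function round can be any of the rounding rules described here.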
FPM is at the heart of the requantization process in quantized neural networks (QNNs), where the 32-bit integers resulting from a convolution or GEMM need to be requantized to a narrower data type (usually int8 or uint8).
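To illustrate where FPM sits in requantization, the sketch below maps an int32 accumulator to int8 in pure Python. It assumes a positive real scale, encoded as (m, s) with `math.frexp` in a way similar to TFLite's quantized multipliers, and round to nearest with ties upward. `quantize_multiplier`, `fpm` and `requantize` are hypothetical helpers, not TVM APIs.

```python
import math
from typing import Tuple

INT8_MIN, INT8_MAX = -128, 127


def quantize_multiplier(scale: float) -> Tuple[int, int]:
    """Encode a positive real scale as (m, s) with scale ~= m * 2^(s - 31)."""
    frac, exp = math.frexp(scale)        # scale = frac * 2^exp, 0.5 <= frac < 1
    m = round(frac * (1 << 31))          # frac as a Q31 integer
    if m == 1 << 31:                     # frac rounded up to 1.0: renormalize
        m //= 2
        exp += 1
    return m, exp


def fpm(x: int, m: int, s: int) -> int:
    """round(x * m * 2^(s - 31)), to nearest with ties upward; assumes s <= 30."""
    total_right_shift = 31 - s
    return (x * m + (1 << (total_right_shift - 1))) >> total_right_shift


def requantize(acc: int, scale: float, zero_point: int) -> int:
    """Map an int32 accumulator to int8: FPM, add the zero point, saturate."""
    m, s = quantize_multiplier(scale)
    q = fpm(acc, m, s) + zero_point
    return max(INT8_MIN, min(INT8_MAX, q))


print(requantize(1000, 0.05, 3))   # 1000 * 0.05 + 3 -> 53
```

Encoding the scale as a Q31 mantissa plus a power-of-two exponent keeps all the arithmetic in integers, which is exactly the operation FPM performs.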
Our analysis shows that, on Arm targets, we can achieve up to a 3% improvement by speeding up FPM. Even though this may not seem like much, in a previous RFC we showed that we are now only 4% away from frameworks like TFLite, so even a 3% improvement is appealing to us.
Background
In its current state, TVM implements FPM as a sequence of Relay operators. The pseudo-code is shown below:
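from tvm import relay

def fixed_point_multiply(x, fixed_point_multiplier, right_shift):
    # Widen to 64 bits so that x * multiplier cannot overflow
    x = relay.cast(x, "int64") * relay.const(fixed_point_multiplier, "int64")
    total_right_shift = right_shift + 31
    # Adding half of the last kept bit before shifting rounds to nearest, ties upward
    pos_rounding_value = relay.const(1 << (total_right_shift - 1), "int64")
    x = x + pos_rounding_value
    x = relay.right_shift(x, relay.const(total_right_shift, "int64"))
    return relay.cast(x, "int32")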
- All the operators (shift, sum, multiplication) are Relay operators.
- Most of the computation is carried out in 64 bits, converting to 32 bits only at the very end of the FPM operator, so it stays very close to the mathematical expression described above.
- TVM picks a to-nearest rounding rule and breaks ties upward (i.e., x.5 becomes x+1).
- The Relay implementation also considers the case of a negative right shift (not shown in the pseudo-code).
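The pseudo-code can be mirrored with plain Python integers to check that adding pos_rounding_value before the shift yields round to nearest with ties upward. This is a sketch, not TVM code: `default_fpm` and `exact_fpm` are hypothetical names, and, like the pseudo-code, it only covers right_shift >= 0.

```python
from fractions import Fraction
import math


def default_fpm(x: int, fixed_point_multiplier: int, right_shift: int) -> int:
    """Integer mirror of the Relay sequence, for right_shift >= 0.

    Python ints do not overflow; with int32 inputs the product magnitude is
    at most 2^62, so the same steps fit in int64.
    """
    prod = x * fixed_point_multiplier                 # cast(x, int64) * multiplier
    total_right_shift = right_shift + 31
    pos_rounding_value = 1 << (total_right_shift - 1)
    prod = prod + pos_rounding_value                  # add half of one output unit
    return prod >> total_right_shift                  # arithmetic shift = floor


def exact_fpm(x: int, fixed_point_multiplier: int, right_shift: int) -> int:
    """round(x * m * 2^(-31 - right_shift)) with ties upward, in exact arithmetic."""
    v = Fraction(x * fixed_point_multiplier, 1 << (right_shift + 31))
    return math.floor(v + Fraction(1, 2))


# The shift-based sequence matches the exact definition on a few inputs
for x in (-7, -6, -1, 0, 1, 6, 7, -(1 << 31), (1 << 31) - 1):
    for s in range(0, 4):
        assert default_fpm(x, 1 << 30, s) == exact_fpm(x, 1 << 30, s)
```

Since floor((P + 2^(k-1)) / 2^k) = floor(P / 2^k + 1/2), the add-then-shift pair is exactly "round half up", which is why ties such as -0.5 go to 0.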
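However, architectures like Armv8-A provide instructions that can execute this operation directly in 32 bits. In particular, it can be shown that, on Armv8-A targets, this operation can be implemented as a combination of the sqrdmulh and srshl instructions, which operate on vectors of four 32-bit lanes:
- sqrdmulh(a, b): executes $((a \cdot b \cdot 2) + \mathrm{round\_const}) \cdot 2^{-32}$, keeping the high half of the doubled product, where round_const $= 2^{31}$; the result saturates to int32. In other words, it is a Q31 multiplication $\mathrm{round}(a \cdot b \cdot 2^{-31})$ that rounds to nearest, breaking ties upward.
- srshl(a, -n): srshl shifts left by a signed amount, so a negative amount executes $a \cdot 2^{-n}$ as a rounding right shift. It rounds to nearest and always breaks ties upward, so a small fixup is needed if ties must be broken differently.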
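The per-lane behavior of the two instructions can be emulated in Python to reason about rounding. This sketch reflects our reading of the Arm descriptions of SQRDMULH and SRSHL for 32-bit lanes, ignores shift amounts of 32 or more, and is not TVM code; the helper names are hypothetical.

```python
INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1


def saturate32(v: int) -> int:
    """Clamp to the int32 range."""
    return max(INT32_MIN, min(INT32_MAX, v))


def wrap32(v: int) -> int:
    """Two's-complement wrap-around to int32."""
    return ((v + (1 << 31)) % (1 << 32)) - (1 << 31)


def sqrdmulh(a: int, b: int) -> int:
    """One int32 lane of SQRDMULH: high half of 2*a*b, rounded and saturated."""
    return saturate32((2 * a * b + (1 << 31)) >> 32)


def srshl(a: int, n: int) -> int:
    """One int32 lane of SRSHL for |n| < 32: shift left by n.

    A negative n is a rounding right shift by -n (ties upward).
    """
    if n >= 0:
        return wrap32(a << n)
    k = -n
    return (a + (1 << (k - 1))) >> k


print(sqrdmulh(1 << 30, 1 << 30))   # 0.5 * 0.5 in Q31 -> 536870912 (2^29)
print(srshl(5, -1), srshl(-5, -1))  # 2.5 -> 3 and -2.5 -> -2: ties go upward
```

Note that the only saturating input pair for sqrdmulh is a = b = -2^31, whose doubled product does not fit in the high half.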
Design and implementation
We propose to create a TVM intrinsic, qmuls, written in TVM IR (TIR), that executes a Q-multiplication followed by a right shift.
The intrinsic signature is as follows:
qmuls(x, y, q, s)
where x and y are two Q-numbers (fixed-point numbers with q fractional bits), and q is passed as the third argument. The right shift s is passed as the last argument to the intrinsic.
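A plain-Python reference for the intended semantics, assuming the default behavior (a single rounding, ties upward) and s >= 0. `to_q` and `qmuls_reference` are hypothetical helpers for illustration, not part of TVM.

```python
def to_q(v: float, q: int) -> int:
    """Encode a real number as a Q-number with q fractional bits."""
    return round(v * (1 << q))


def qmuls_reference(x: int, y: int, q: int, s: int) -> int:
    """round(x * y * 2^(-q - s)) with one rounding step, ties upward.

    x and y are Q-numbers with q fractional bits and s >= 0 is the right
    shift; q + s must be at least 1.
    """
    total_shift = q + s
    return (x * y + (1 << (total_shift - 1))) >> total_shift


# Plain Q15 multiplication (s = 0): 0.75 * -0.5 = -0.375
print(qmuls_reference(to_q(0.75, 15), to_q(-0.5, 15), 15, 0) == to_q(-0.375, 15))
# FPM as a special case: q = 31, m = 2^30 (0.5), right shift 1 -> 100 * 0.25
print(qmuls_reference(100, to_q(0.5, 31), 31, 1))
```

With q = 31 the call reduces to FPM with a right shift, and with s = 0 it is a bare Q-multiplication, which is the generality argued for below.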
There are multiple reasons to introduce this:
- It is general enough to be reused whenever we need to multiply Q-numbers (the shift can be set to zero to perform only a Q-multiplication).
- Each hardware vendor can provide a hardware-specific implementation of the operation.
- The intrinsic can be overloaded by different targets using tvm.target.intrin.register_intrin_rule. This is a simpler approach than overloading through a compute strategy or tensorization.
In the sections below, we describe the main code changes of this RFC.
Relay changes
We created a new Relay operator, fixed_point_multiplication, and registered a compute and an injective_schedule for it.
- The Relay operator has two attributes: the multiplier (m) and the right shift (s).
- The compute is a simple loop over the array (i.e., much like a unary operation).
- The injective schedule vectorizes the loop.
TIR changes
The main TIR changes are the following:
- We registered a tvm.intrin.rule.default.qmuls TVM intrinsic that executes the same operations as the Relay implementation (but using TIR operators).
- We created a TIR operator qmuls(x, y, q, s) which executes the call:
call_intrin(x.dtype, "qmuls", x, y, q, s)
Intrinsic overload
In order to overload the intrinsic for Armv8-A, we need to make use of tvm.target.intrin.register_intrin_rule. However, intrinsics are overloaded by target_name, which in the case of Armv8-A is only llvm.
This means that, in order to specialize for llvm.aarch64, we had to hack into lower_intrin.cc and register a new llvm.intrin.rule.aarch64. pattern.
Given the above tweak, we could easily use the tvm.target.intrin.register_intrin_rule method to register a version of qmuls tailored for the Armv8-A ISA. The result is similar to the following:
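import tvm

def _qmuls_arm(op):
    x, y, q, s = op.args
    # Use the default rule unless x is an int32x4 vector and Q is 31
    if x.dtype != "int32x4" or q.value != 31:
        return op
    # sqrdmulh: Q31 multiplication of x and y, rounded to nearest
    sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype, "llvm.aarch64.neon.sqrdmulh", tvm.tir.const(2, "uint32"), x, y)
    # srshl shifts left by a signed amount: a right shift by s is a shift by -s
    neg_s = -s
    # -1 on negative lanes (when s > 0), so that ties round away from zero
    fixup = (sqrdmulh & neg_s) >> 31
    fixed_up_x = sqrdmulh + fixup
    out = tvm.tir.call_llvm_intrin(op.dtype, "llvm.aarch64.neon.srshl", tvm.tir.const(2, "uint32"), fixed_up_x, neg_s)
    return out

tvm.target.intrin.register_intrin_rule("llvm.aarch64", "qmuls", _qmuls_arm, override=True)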
A few notes on the above implementation:
- We also handle the case of a negative right shift (not shown in the code).
- The fixup controls tie-breaking: srshl alone rounds to nearest and breaks ties upward, while adding -1 to negative lanes first makes ties on negative values round away from zero.
- We use the default implementation when the data (x) is not a vector or when Q is not 31.
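To see what the fixup does, the lowering can be emulated lane by lane in Python. This is a sketch under stated assumptions: the instruction semantics are our reading of SQRDMULH and SRSHL for 32-bit lanes, s is a non-negative right shift, and the helper names are hypothetical. It compares the Arm sequence against a single rounding with ties upward.

```python
INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1


def sqrdmulh(a: int, b: int) -> int:
    """One int32 lane of SQRDMULH: rounded, saturated high half of 2*a*b."""
    return max(INT32_MIN, min(INT32_MAX, (2 * a * b + (1 << 31)) >> 32))


def srshl_right(a: int, k: int) -> int:
    """One lane of srshl(a, -k) for 0 <= k < 32: rounding right shift, ties upward."""
    return a if k == 0 else (a + (1 << (k - 1))) >> k


def qmuls_arm_lane(x: int, y: int, s: int) -> int:
    """Emulates one lane of the Arm lowering for q = 31 and 0 <= s < 32."""
    high = sqrdmulh(x, y)
    # -s has its sign bit set when s > 0, so fixup is -1 exactly on negative lanes
    fixup = (high & -s) >> 31
    # high is never INT32_MIN for int32 inputs, so this add cannot overflow
    return srshl_right(high + fixup, s)


def default_lane(x: int, y: int, s: int) -> int:
    """Default rule: a single rounding of x * y * 2^(-31 - s), ties upward."""
    k = 31 + s
    return (x * y + (1 << (k - 1))) >> k


# y = 2^30 is 0.5 in Q31, so x = +/-6 with s = 1 gives the ties +/-1.5
print(qmuls_arm_lane(6, 1 << 30, 1), default_lane(6, 1 << 30, 1))    # 2 2
print(qmuls_arm_lane(-6, 1 << 30, 1), default_lane(-6, 1 << 30, 1))  # -2 -1
```

The two results agree on the positive tie and differ by one on the negative tie, which is one source of the +1/-1 differences discussed under Precision.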
Final notes on performance and precision
Performance
As previously mentioned, the best-case performance gain from using these intrinsics seems to be around 3%, but the improvement we measured is only around 1.5%:
- The 3% improvement is for the case in which the requantization operation is fused within the main computation loop (e.g., GEMM or spatial convolution).
- In TVM, a quantized convolution is lowered as a qnn.conv2d followed by a requantize operator. This makes it impossible to fuse the requantization into the main compute loop, which explains why we cannot fully achieve the 3% improvement.
Precision
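There are corner cases in which the intrinsic implementation has a +1/-1 error compared to the default TVM implementation. This is because we round twice instead of once. In other words, with y a Q31 multiplier and s the right shift:
- default behavior: $\mathrm{out} = \mathrm{round}\left(x \cdot y \cdot 2^{-31-s}\right)$
- Arm behavior: $\mathrm{out} = \mathrm{round}\left(\mathrm{round}\left(x \cdot y \cdot 2^{-31}\right) \cdot 2^{-s}\right)$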
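A minimal worked instance of the double rounding, assuming x = 1, y = 2^30 (0.5 in Q31), q = 31, s = 1 and ties broken upward at both steps:

```latex
\text{default:}\quad \mathrm{round}\left(1 \cdot 2^{30} \cdot 2^{-31-1}\right) = \mathrm{round}(0.25) = 0
\\
\text{Arm:}\quad \mathrm{round}\left(1 \cdot 2^{30} \cdot 2^{-31}\right) = \mathrm{round}(0.5) = 1,
\qquad \mathrm{round}\left(1 \cdot 2^{-1}\right) = \mathrm{round}(0.5) = 1
```

Rounding the intermediate 0.5 up to 1 creates a new tie at the second step, so the Arm sequence returns 1 where the default returns 0.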
PR
The PR for this RFC is here: https://github.com/apache/incubator-tvm/pull/5980