It seems this can be fixed by adding a “0.5” followed by an integer division. What makes things complicated is that Torch applies banker’s rounding, that is, Round to Nearest Even (RNE):
RNE(1.5) → 2
RNE(2.5) → 2
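For a concrete illustration, here is a minimal NumPy sketch (nothing TVM- or Torch-specific) contrasting RNE with the “add 0.5, then integer-divide” approach, i.e. round-half-up:

```python
import numpy as np

vals = np.array([0.5, 1.5, 2.5, 3.5])

# "add 0.5, then take the integer part" == round half up (for positive values)
half_up = np.floor(vals + 0.5)

# numpy.round() (like Python's round()) implements RNE / banker's rounding
rne = np.round(vals)

print(half_up)  # [1. 2. 3. 4.]
print(rne)      # [0. 2. 2. 4.]
```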
Thus, a bigger problem is that Torch uses RNE as its standard rounding algorithm. For example, consider q_multiply_shift, which takes four parameters: x, y, q, s. Let’s assume q = 31, s = -1, and a value of x * y laid out as follows:
                       2nd rounding
                       |  1st rounding
                       |  |
                       V  v
bit idx: 63 62 ... 32 31 30 ... 0
bit val:  s  s         0  1      # s for sign bit,
                                 # { bit_62 - bit_31 } will be kept
                                 # as the result (a 32bit integer)
                                 # of neon.sqrdmulh
The first rounding will produce a carry:
                    2nd rounding
                    |
                    V
bit idx: 62 ... 32 31
bit val:  s         1    # The 2nd rounding is done by neon.srshl
                         # (shift left "s = -1" with rounding)
                         # Note: bit_32 will accept the "carry"
                         # out of "rounding(bit_31...)"
This carry is then propagated further into bit 32 by the second rounding, which produces a different value from the DEFAULT path (aka rounding once).
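To make the discrepancy concrete, here is a small Python sketch of the two paths following the scenario in the diagrams above (this is my own model, not the actual TVM lowering; the helper names sqrdmulh, srshl_by_neg1 and round_once are purely illustrative, and saturation is omitted):

```python
# Two roundings, as on the ARM path (neon.sqrdmulh followed by neon.srshl
# with a shift of s = -1).
def sqrdmulh(x, y):
    # rounding doubling multiply, high half: (2*x*y + 2^31) >> 32
    return (2 * x * y + (1 << 31)) >> 32

def srshl_by_neg1(v):
    # rounding shift right by 1: (v + 1) >> 1
    return (v + 1) >> 1

# One rounding over the whole 32-bit shift (q = 31 plus the extra shift of 1),
# modelling the DEFAULT path that rounds only once.
def round_once(x, y):
    return (x * y + (1 << 31)) >> 32

# Pick x, y so that x * y = 2^32 + 2^30, i.e. bit 31 = 0 and bit 30 = 1,
# matching the bit pattern in the diagrams.
x, y = 1 << 16, 5 << 14

print(srshl_by_neg1(sqrdmulh(x, y)))  # 2: the 1st rounding carries into
                                      #    bit 31, the 2nd carries into bit 32
print(round_once(x, y))               # 1: the exact value x*y / 2^32 is 1.25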
I gave up on matching exactly against PT results while I was developing quantized PT model support. But that might have been because I was using fbgemm exclusively… Are you saying it is feasible to be bit-exact on ARM + qnnpack?
Sorry, it’s not clear to me whether you are proposing something or describing the situation in general. Do you want us to do something?
We’ve managed to get a result (using our TorchScript QAT model) at a precision of rtol=1e-05, atol=1e-08 against torch’s qnnpack, with the above-mentioned matters (except RNE) taken into account.
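(For reference, rtol=1e-05 / atol=1e-08 are numpy’s allclose defaults; a check of that kind looks like the sketch below, where tvm_out and torch_out are hypothetical placeholders for the two dequantized outputs on the same input:)

```python
import numpy as np

# Hypothetical placeholders for the TVM output and the Torch/qnnpack output.
tvm_out = np.array([0.123456, 0.654321], dtype=np.float32)
torch_out = np.array([0.123456, 0.654321], dtype=np.float32)

# rtol=1e-05, atol=1e-08 are numpy's defaults for allclose / assert_allclose.
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-05, atol=1e-08)
```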
I’m planning to:

- Submit a “two-rounding precision fixup” patch later and request a review (not sure about the impact on TFLite bit-exactness)
- Raise a question: shall we switch to RNE (Round to Nearest Even)?
Pros:

- Torch is quite popular these days
- Under RNE, a fractional part of exactly 0.5 has a 50-50 chance of producing a carry, which introduces less bias (see the sketch after this list):

| Fractional part | Leads to a carry? |
|-----------------|-------------------|
| > 0.5           | yes               |
| < 0.5           | no                |
| == 0.5          | 50%               |

Cons:

- Limited HW support (?)
- Need to modify q_multiply_shift on each device
- Also need to apply RNE to other ops where rounding is involved (e.g. AdaptiveAvgPool2d)
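To illustrate the bias point, a quick NumPy sketch (values chosen purely for illustration) comparing the average rounding error of half-up rounding vs RNE on exact-half values:

```python
import numpy as np

# values whose fractional part is exactly 0.5
vals = np.arange(100) + 0.5

half_up = np.floor(vals + 0.5)   # always carries: +0.5 error every time
rne     = np.round(vals)         # carries only when the integer part is odd

print((half_up - vals).mean())   # 0.5 -> systematic upward bias
print((rne - vals).mean())       # 0.0 -> the up/down roundings cancel out
```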