[ONNX] Incorrect inference results for prequantized model

Hi all,

I was trying to compile and run a prequantized MobileNetV2 model from ONNX using TVM. When I compare its results with those of ONNX Runtime, they differ.

The top-5 decoded predictions from TVM and ONNX Runtime are similar, but ordered differently.

The raw output tensors, however, differ significantly. According to numpy's allclose:

Not equal to tolerance rtol=0.001, atol=1e-06

Mismatched elements: 930 / 1000 (93%)
Max absolute difference: 2.5623133
Max relative difference: 11.
 x: array([[-3.84347 , -1.121012, -5.605061, -3.523181, -4.804338, -4.484048,
        -4.323904, -2.242024, -1.601446,  0.160145,  0.960868, -0.160145,
         2.08188 , -2.722458, -3.523181, -0.160145, -1.921735, -2.882602,...
 y: array([[-2.722458, -1.921735, -6.405783, -4.484048, -5.124627, -4.163759,
        -3.523181, -3.363036, -2.242024, -0.160145,  0.      , -0.480434,
         1.76159 , -3.683325, -3.84347 , -0.160145, -2.242024, -3.683325,...

The following script should be able to reproduce this output.

from PIL import Image
import numpy as np
import onnx
import onnxruntime as ort
import os
import tvm
import tvm.relay
import tvm.contrib.graph_executor
import urllib.request

image_url = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
model_url = 'https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-12-int8.onnx'

model_path = '/tmp/mobilenetv2-12-int8.onnx'
image_path = '/tmp/kitten.jpg'

if not os.path.exists(image_path):
    urllib.request.urlretrieve(image_url, image_path)
if not os.path.exists(model_path):
    urllib.request.urlretrieve(model_url, model_path)


# Apply input preprocessing
input_image = Image.open(image_path).resize((224, 224))
input_tensor = np.asarray(input_image)
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_stddev = np.array([0.229, 0.224, 0.225])
input_tensor = ((input_tensor / 255.) - imagenet_mean) / imagenet_stddev
input_tensor = np.transpose(input_tensor, (2, 0, 1))
input_tensor = np.expand_dims(input_tensor, axis=0).astype(np.float32)

# ONNX Runtime evaluation
onnx_model = onnx.load(model_path)
session = ort.InferenceSession(onnx_model.SerializeToString())
input_name = session.get_inputs()[0].name
onnx_result = session.run(None, {input_name: input_tensor})[0]

# TVM evaluation
shape_dict = { input_name: (1, 3, 224, 224) }
model, params = tvm.relay.frontend.from_onnx(onnx_model, shape=shape_dict)

with tvm.transform.PassContext(opt_level=0):
    model = tvm.relay.transform.InferType()(model)
    model = tvm.relay.qnn.transform.CanonicalizeOps()(model)
    lib = tvm.relay.build_module.build(model, target='llvm', params=params)

module = tvm.contrib.graph_executor.GraphModule(lib['default'](tvm.cpu()))
module.set_input(input_name, input_tensor)
module.run()
tvm_result = module.get_output(0).numpy()


# Comparison
def extract_predictions(tensor):
    return np.argsort(np.squeeze(tensor))[-5:][::-1]

tvm_pred = extract_predictions(tvm_result)
onnx_pred = extract_predictions(onnx_result)
print(f'Expected predictions: {onnx_pred}')
print(f'Actual predictions: {tvm_pred}')
np.testing.assert_allclose(tvm_result, onnx_result, rtol=1e-3, atol=1e-6)

Are differences of this magnitude to be expected for prequantized models? The floating-point models I've built with TVM so far have all matched the expected results much more closely. Or did I make a mistake in the code used to build and evaluate the model with TVM?
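
For what it's worth, the TVM outputs above all appear to be multiples of roughly 0.160145, which I assume is the model's output dequantization scale. Below is a rough sketch of a looser comparison that reuses the variables from the script above; the scale value and the idea of comparing top-5 label sets instead of raw tensors are my own assumptions, not anything from the model metadata.

# Sketch: scale-aware comparison, reusing tvm_result/onnx_result and
# tvm_pred/onnx_pred from the script above. The output scale of ~0.160145
# is inferred from the printed values and is an assumption.
assumed_output_scale = 0.160145
steps = np.abs(tvm_result - onnx_result) / assumed_output_scale
print(f'max disagreement: {steps.max():.1f} quantization steps')

# Order-insensitive top-5 comparison
print('top-5 label sets match:', set(tvm_pred) == set(onnx_pred))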

I’ve also encountered a similar issue while attempting to build the PP-OCR detection model from ONNX.

As you can see, the output of the Relay model (right) does not resemble the original output (left) and has periodic artifacts. Here is a link to the Colab.
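
For reference, the mismatch can be quantified with a simple numpy comparison. This is only a sketch; ort_out and tvm_out are placeholder names for the ONNX Runtime and Relay output maps, not the actual variable names in the Colab.

# Sketch: quantify how far apart the two detection maps are.
# ort_out / tvm_out are placeholders for the two outputs shown above.
diff = np.abs(ort_out - tvm_out)
print('max abs diff:', diff.max(), '  mean abs diff:', diff.mean())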

Looking through the forum, this seems to be an issue specific to the ONNX frontend rather than a general problem with pre-quantized models (but I might be wrong).
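
If that is the case, one way to narrow it down might be to dump the Relay that the ONNX frontend emits, before and after canonicalizing the qnn ops, and check how the QLinear* operators were translated. This is just a sketch that reuses onnx_model and shape_dict from the script above; I have not verified that it pinpoints the problem.

# Sketch: inspect the imported Relay module.
mod, params = tvm.relay.frontend.from_onnx(onnx_model, shape=shape_dict)
print(mod['main'])  # qnn.* ops as imported from the ONNX QLinear* operators

mod = tvm.relay.transform.InferType()(mod)
mod = tvm.relay.qnn.transform.CanonicalizeOps()(mod)
print(mod['main'])  # after lowering the qnn ops to regular Relay ops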

I have also encountered the same problem. I see people have replied, but no one seems to have provided a solution.