The accuracy of the compiled model declines! Is it a bug?

When the model xception-imagenet_origin.onnx is compiled, the accuracy of the compiled model decreases from 0.78 to 0.77.

12 out of 300 pictures produce inconsistent prediction results. The figure below shows them in the form: prediction top_1_class(top_1 class probability)

[image: table of mismatched top-1 predictions, ONNX Runtime vs. TVM]

Besides, even for the pictures whose predicted class stays consistent, the predicted probability values change. I compared the predictions of the model before and after compilation using the statement np.testing.assert_allclose(res_onnx, res_tvm, rtol=1e-3, atol=1e-3), and the output is as follows.

[image: np.testing.assert_allclose mismatch report]
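Here is a minimal sketch to quantify how large the deviation actually is (it assumes res_onnx and res_tvm are the (batch_size, 1000) output arrays produced by the script below):

import numpy as np

# Sketch: measure how far the ONNX Runtime and TVM outputs diverge.
abs_diff = np.abs(res_onnx - res_tvm)
rel_diff = abs_diff / (np.abs(res_tvm) + 1e-12)
print("max abs diff:", abs_diff.max())
print("max rel diff:", rel_diff.max())
# Count samples whose top-1 class flips between the two runtimes.
mismatch = int(np.sum(np.argmax(res_onnx, axis=1) != np.argmax(res_tvm, axis=1)))
print("top-1 mismatches:", mismatch, "of", res_onnx.shape[0])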

The script to reproduce the issue is as follows:

# -*- encoding=utf-8 -*-
import keras
import tvm
import tvm.relay as relay
import numpy as np
from PIL import Image
from tvm.contrib import graph_runtime
import onnx
import onnxruntime as rt

target = 'llvm -mcpu=core-avx2'
ctx = tvm.cpu(0)


def image_resize(x, shape):
    x_return = []
    for x_test in x:
        tmp = np.copy(x_test)
        img = Image.fromarray(tmp.astype('uint8')).convert('RGB')
        img = img.resize(shape, Image.ANTIALIAS)
        x_return.append(np.array(img))
    return np.array(x_return)


def getTestData():
    data = np.load("/share_container/data/dataset/imagenet-val-10.npz")
    x, y = data['x_test'], data['y_test']
    input_shape = (299, 299)
    input_preprocessor = keras.applications.xception.preprocess_input
    x_resize = image_resize(np.copy(x), input_shape)
    x_test = input_preprocessor(x_resize)
    y_test = keras.utils.to_categorical(y, num_classes=1000)
    return x_test, y_test


def compare(model_name):
    batch_size = 10  # test_pic number
    x_test, y_test = getTestData()  # get test images
    x_test = x_test[:batch_size]
    predict_model = onnx.load(model_name)

    sess = rt.InferenceSession(model_name)
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    res_onnx = sess.run([label_name], {input_name: x_test.astype(np.float32)})[0]

    # ############################ load the onnx model into a relay IRModule  ##########################
    input_shape = (batch_size, 299, 299, 3)
    output_shape = (batch_size, 1000)

    shape_dict = {input_name: input_shape}

    irmod, params = relay.frontend.from_onnx(predict_model, shape_dict, freeze_params=True)
    irmod = relay.transform.DynamicToStatic()(irmod)

    # ########################## Compile the RelayIR ####################################################
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build_module.build(
            irmod, target=target, params=params)

    module = graph_runtime.create(graph, lib, ctx)

    data = x_test.astype('float32')
    module.set_input(input_name, data)
    module.set_input(**params)
    module.run()

    res_tvm = module.get_output(0, tvm.nd.empty(output_shape)).asnumpy()
    # ###########################  calc tvm accuracy   ###############################
    print("consistent prediction result:")
    print("onnx prediction   TVM prediction)
    for i in range(batch_size):
        ori_class = np.argmax(res_onnx[i])
        compile_class = np.argmax(res_tvm[i])
        if ori_class != compile_class:
            print(str(ori_class) + "(" + str(np.max(res_onnx[i])) + ")\t"
                  + str(compile_class) + "(" + str(np.max(res_tvm[i])) + ")")

    np.testing.assert_allclose(res_onnx, res_tvm, rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
    model_name = "xception-imagenet_origin.onnx"
    compare(model_name)




The model and a small dataset (10 pictures) can be accessed via this link:

https://drive.google.com/drive/folders/1YpO_61A5Z0hycyF1Z8tjxhJj8FIgT3DB?usp=sharing


@tqchen @FrozenGene @merrymercy @haichen

Is this slight decrease in accuracy considered a bug? If so, do we need to locate and fix it, even though that seems very difficult?

This seems to be the same issue: After Tuning, the model has a inconsistent prediction result with Keras - #8 by FrozenGene

I don't think it is the same problem, because I have already fixed that bug using the patch you provided. In addition, this bug still exists when batch_size = 1.

I tried running 11 different ONNX models, and only this one has this problem. I think the particular structure of the xception-imagenet model may be what triggers the bug.
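To try to narrow down where the outputs first diverge, the per-node outputs can be dumped with the debug variant of the graph runtime. This is only a rough sketch, assuming the installed TVM version ships tvm.contrib.debugger.debug_runtime, and it reuses graph, lib, params, input_name and x_test from the script above:

# Rough sketch: run the same compiled graph with the debug runtime so that each
# node's output (and timing) is dumped, to see where the divergence starts.
from tvm.contrib.debugger import debug_runtime

dbg = debug_runtime.create(graph, lib, ctx, dump_root="/tmp/tvmdbg")
dbg.set_input(input_name, x_test.astype('float32'))
dbg.set_input(**params)
dbg.run()
# The exact dump layout under /tmp/tvmdbg depends on the TVM version.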

How about setting opt_level to 2?
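That is only a one-line change to the PassContext in the script above, e.g.:

# Rebuild with a lower optimization level; everything else stays the same.
with tvm.transform.PassContext(opt_level=2):
    graph, lib, params = relay.build_module.build(irmod, target=target, params=params)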

opt_level = 1 or 2 still produces the same wrong predictions. It does not seem to be related to opt_level.

OK, this should be a bug. You could file it.

Thanks. Should I submit this bug on GitHub?

Yes, as this is an issue.

OK, thanks for your confirmation.