[AutoTVM] After tuning, the compiled model loses its prediction accuracy

After tuning, if I compile the model inside `with autotvm.apply_graph_best(...)`, the compiled model no longer predicts correctly. Why?

When I remove the line `with autotvm.apply_graph_best(graph_opt_sch_file):`, the compiled model produces the same predictions as the original model.

Note that, to reduce tuning time, we run only 2 tuning trials per task and set `runner=autotvm.LocalRunner(number=1, repeat=1, min_repeat_ms=2)`. Most of the remaining code follows the tutorial https://tvm.apache.org/docs/tutorials/autotvm/tune_relay_x86.html#sphx-glr-tutorials-autotvm-tune-relay-x86-py

When I compile other models with TVM in the same way, they also produce wrong predictions, so the bug does not seem to be related to this particular model.

The code is as follows:

# -*- encoding=utf-8 -*-
import keras
import os
import tvm
import tvm.relay as relay
import numpy as np
from PIL import Image
from tvm.contrib import graph_runtime
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
from tvm import autotvm

import sys
import codecs
# Force UTF-8 for everything written to stdout.
# ``TextIOWrapper.reconfigure`` (Python 3.7+) changes the encoding in place,
# so the stream object stays the same: ``sys.__stdout__`` and any reference
# taken before this line keep working, and line buffering is preserved.
# ``detach()`` would instead leave the original wrapper unusable, so it is
# only used as a fallback on streams without ``reconfigure``.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")
else:
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())


# Root directory that holds the ``dataset`` and ``origin_model`` folders.
base_dir = '/share_container/data/'

def image_resize(x, shape):
    """Resize every image in ``x`` to ``shape`` and stack the results.

    Each element is cast to ``uint8``, converted to an RGB PIL image, resized
    with the Lanczos filter and converted back to a NumPy array.

    Args:
        x: Iterable of image arrays accepted by ``PIL.Image.fromarray`` after
            a ``uint8`` cast.
        shape: Target size forwarded to ``PIL.Image.Image.resize``. PIL
            interprets it as ``(width, height)``.

    Returns:
        ``np.ndarray`` holding the resized RGB images, stacked along axis 0.
    """
    resized = []
    for image in x:
        # ``astype`` returns a new array, so the caller's data is never
        # modified and no separate ``np.copy`` is needed.
        pil_img = Image.fromarray(np.asarray(image).astype('uint8')).convert('RGB')
        # ``Image.ANTIALIAS`` was removed in Pillow 10. It was always an alias
        # of ``Image.LANCZOS``, so the resampling filter is unchanged.
        resized.append(np.array(pil_img.resize(shape, Image.LANCZOS)))
    return np.array(resized)


def getTestData(project, dataset_dir):
    """Load the ImageNet validation subset and preprocess it for ResNet-50.

    Args:
        project: Project name derived from the model file name; not used by
            this function, kept for interface compatibility.
        dataset_dir: Directory containing ``imagenet-val-1500.npz``, which
            must provide the ``x_test`` and ``y_test`` arrays.

    Returns:
        Tuple ``(x_test, y_test, input_shape)``: images resized to
        ``input_shape`` and passed through Keras' ResNet-50
        ``preprocess_input``, labels one-hot encoded over 1000 classes, and
        the ``(224, 224)`` resize target.
    """
    archive = np.load(os.path.join(dataset_dir, "imagenet-val-1500.npz"))
    images = archive['x_test']
    labels = archive['y_test']

    target_size = (224, 224)
    preprocess = keras.applications.resnet50.preprocess_input

    preprocessed = preprocess(image_resize(np.copy(images), target_size))
    one_hot = keras.utils.to_categorical(labels, num_classes=1000)
    return preprocessed, one_hot, target_size


def tune_kernels(tasks,
                 measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log',
                 n_trial=2):
    """Tune each AutoTVM task and append the measured records to a log file.

    Args:
        tasks: AutoTVM tasks, e.g. from ``autotvm.task.extract_from_program``.
        measure_option: Value returned by ``autotvm.measure_option(...)``.
        tuner: Search strategy: ``'xgb'``/``'xgb-rank'`` (XGBTuner with rank
            loss), ``'ga'``, ``'random'`` or ``'gridsearch'``.
        early_stopping: Forwarded unchanged to ``Tuner.tune``.
        log_filename: File the ``log_to_file`` callback writes records to.
        n_trial: Maximum number of configurations measured per task. The
            default of 2 keeps runs short but explores almost nothing. Pass
            ``None`` to try each task's whole config space, as the upstream
            tutorial does. The value is clamped to the config-space size.

    Raises:
        ValueError: If ``tuner`` is not one of the names above (raised when
            the first task is processed).
    """
    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # Create the tuner for this task.
        if tuner in ('xgb', 'xgb-rank'):
            tuner_obj = XGBTuner(task, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(task)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(task)
        else:
            # ``%s`` instead of ``+`` so a non-str value still yields ValueError.
            raise ValueError("Invalid tuner: %s" % (tuner,))

        # Never ask for more trials than the space holds; this also keeps the
        # progress-bar total accurate for tiny spaces.
        space_size = len(task.config_space)
        trials = space_size if n_trial is None else min(n_trial, space_size)
        tuner_obj.tune(n_trial=trials,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(trials, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])


def tune_graph(graph, dshape, records, opt_sch_file, target, use_DP=True, input_tensor="input_1"):
    """Choose data layouts across ``nn.conv2d`` ops and save the schedule.

    Args:
        graph: Relay function to tune (``compare`` passes ``irmod["main"]``).
        dshape: Shape of the graph input named ``input_tensor``.
        records: Kernel tuning log written by ``tune_kernels``.
        opt_sch_file: Output file for the graph-level schedule; ``compare``
            later loads it with ``autotvm.apply_graph_best``.
        target: Compilation target string.
        use_DP: ``True`` selects ``DPTuner``, ``False`` selects ``PBQPTuner``.
        input_tensor: Name of the graph input.
    """
    conv_ops = [relay.op.get("nn.conv2d")]
    if use_DP:
        tuner_cls = DPTuner
    else:
        tuner_cls = PBQPTuner

    graph_tuner = tuner_cls(graph, {input_tensor: dshape}, records, conv_ops, target)
    # Very small execution count: presumably chosen to keep the
    # layout-transform benchmark fast at the cost of timing accuracy.
    graph_tuner.benchmark_layout_transform(min_exec_num=2)
    graph_tuner.run()
    graph_tuner.write_opt_sch2record_file(opt_sch_file)


def _tune_model(irmod, params, target, input_shape, input_tensor,
                log_file, graph_opt_sch_file):
    """Run kernel-level then graph-level AutoTVM tuning for ``irmod``."""
    print('\033[1;35m build a task \033[0m')
    tuning_option = {
        'log_filename': log_file,
        'tuner': 'random',
        'early_stopping': None,

        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(),
            runner=autotvm.LocalRunner(number=1, repeat=1,
                                       min_repeat_ms=2),
        ),
    }

    tasks = autotvm.task.extract_from_program(irmod["main"], target=target,
                                              params=params,
                                              ops=(relay.op.get("nn.conv2d"),))
    print('\033[1;35m', log_file, ' ', graph_opt_sch_file, '\033[0m')
    print('\033[1;35m run tuning tasks \033[0m')
    # NOTE(review): results are cached purely by file existence. A log left
    # over from another model, target or batch size (or from an interrupted
    # run) is silently reused, so delete both files whenever anything
    # upstream changes.
    kernels_retuned = not os.path.exists(log_file)
    if kernels_retuned:
        tune_kernels(tasks, **tuning_option)
    print('\033[1;35m After tune_kernels \033[0m')
    # A graph schedule is derived from the kernel log. Recompute it whenever
    # the log was regenerated, so a stale schedule is never paired with new
    # records.
    if kernels_retuned or not os.path.exists(graph_opt_sch_file):
        tune_graph(irmod["main"], input_shape, log_file, graph_opt_sch_file,
                   target=target, input_tensor=input_tensor)


def _build_relay(irmod, target, params):
    """Compile ``irmod`` at ``opt_level=3``; return ``(graph, lib, params)``."""
    with tvm.transform.PassContext(opt_level=3):
        return relay.build_module.build(irmod, target=target, params=params)


def compare(model_path, is_tuning=True, log_file=None, graph_opt_sch_file=None):
    """Compare a Keras model's predictions with its TVM-compiled counterpart.

    The Keras model scores the first 100 test images and its top-1 accuracy
    is printed. The model is then imported into Relay, optionally tuned with
    AutoTVM, compiled for ``llvm -mcpu=core-avx2`` and run on the first 10
    images. The largest absolute output difference is printed, along with
    every image whose top-1 class differs between Keras and TVM.

    Args:
        model_path: Path to the Keras model file. The part of its file name
            before the first ``-`` is passed to ``getTestData`` as the project
            name.
        is_tuning: If ``True`` (default), kernels and graph layouts are tuned
            (existing logs are reused) and the build runs under
            ``autotvm.apply_graph_best``. If ``False``, tuning is skipped and
            the model is built without ``apply_graph_best``. This makes it
            easy to check whether a wrong result comes from the tuned
            schedule.
        log_file: Kernel tuning log. Defaults to ``"<model file name>.log"``.
        graph_opt_sch_file: Graph tuning output. Defaults to
            ``"<model file name>_graph_opt.log"``.
    """
    model_name = os.path.basename(model_path)
    # Same naming scheme as the ``__main__`` section, so the function no
    # longer depends on module globals that exist only when run as a script.
    if log_file is None:
        log_file = "%s.log" % model_name
    if graph_opt_sch_file is None:
        graph_opt_sch_file = "%s_graph_opt.log" % model_name

    num_eval = 100   # images scored by the original Keras model
    batch_size = 10  # images run through the compiled TVM module

    # ######################## reference run with Keras ########################
    project_name = model_name.split('-')[0]
    x_test, y_test, input_pic_size = getTestData(project_name, base_dir + "dataset")
    predict_model = keras.models.load_model(model_path)
    predict_model.summary()
    res_frame = predict_model.predict(x_test[:num_eval])
    print(res_frame)
    print(res_frame.shape)
    # Accuracy over the images actually scored, so a dataset smaller than
    # ``num_eval`` does not raise IndexError.
    hits = np.argmax(res_frame, axis=1) == np.argmax(y_test[:len(res_frame)], axis=1)
    print("acc=", hits.mean() if len(hits) else 0.0)

    # ###################### load keras model into Relay #######################
    input_shape = (batch_size, 3, input_pic_size[0], input_pic_size[1])
    output_shape = (batch_size, 1000)
    input_tensor = predict_model.input.name.split(':')[0]
    target = 'llvm -mcpu=core-avx2'
    ctx = tvm.cpu(0)
    irmod, params = relay.frontend.from_keras(predict_model, {input_tensor: input_shape})

    # ######################## optional tuning + compile #######################
    if is_tuning:
        _tune_model(irmod, params, target, input_shape, input_tensor,
                    log_file, graph_opt_sch_file)
    print('\033[1;35m Compile...\033[0m')
    if is_tuning:
        with autotvm.apply_graph_best(graph_opt_sch_file):
            graph, lib, params = _build_relay(irmod, target, params)
    else:
        graph, lib, params = _build_relay(irmod, target, params)

    # ########################## run the compiled module #######################
    module = graph_runtime.create(graph, lib, ctx)
    # The Relay input was declared NCHW (see ``input_shape``). The test
    # images are channel-last, hence the transpose.
    data = x_test[:batch_size].transpose([0, 3, 1, 2]).astype('float32')
    module.set_input(input_tensor, data)
    module.set_input(**params)
    module.run()
    res_tvm = module.get_output(0, tvm.nd.empty(output_shape)).asnumpy()
    print(res_tvm)
    print(res_tvm.shape)

    # ################################ calc diff ###############################
    print('\033[1;35m Begin Calculate abs_diff\033[0m!')
    num_cmp = min(batch_size, len(res_frame))
    if num_cmp:
        abs_diff = np.max(np.abs(res_frame[:num_cmp] - res_tvm[:num_cmp]))
        print("max abs_diff=", abs_diff)
    mismatches = 0
    for i in range(num_cmp):
        ori_class = np.argmax(res_frame[i])
        compile_class = np.argmax(res_tvm[i])
        if ori_class != compile_class:
            mismatches += 1
            print("[error]: unequal compiling result:  \n"
                  + str(ori_class) + "(" + str(np.max(res_frame[i])) + ")\t"
                  + str(compile_class) + "(" + str(np.max(res_tvm[i])) + ")\t "
                  + "label:", np.argmax(y_test[i]))
    print("top-1 mismatches: %d/%d" % (mismatches, num_cmp))


if __name__ == '__main__':
    # Keras model to evaluate, located under ``<base_dir>origin_model/``.
    model_name = "resnet50-imagenet_origin.h5"
    model_path = "{}origin_model/{}".format(base_dir, model_name)
    print(model_path)

    # Tuning artefacts named after the model. Kept as module-level names
    # because ``compare`` may read them as globals.
    log_file = "%s.log" % model_name
    graph_opt_sch_file = "%s_graph_opt.log" % model_name
    compare(model_path)

The results are as follows:

Output format: `origin_model_predict_result(confidence)  compiled_model_predict_result(confidence)  label`

[image: screenshot of the results — not included in this copy]

I hope someone can help. Thanks in advance.