After tuning, the model's prediction results are inconsistent with Keras

When I skip auto tuning, TVM and Keras produce consistent predictions. However, after tuning the operators with only 2 tuning trials, the compiled model loses its ability to predict correctly.

WARNING:root:Could not find any valid schedule for task Task(func_name=layout_transform, args=(('TENSOR', (10, 4, 28, 28, 128), 'float32'), 'NCHW128c', 'NCHW2c'), kwargs={}, workload=('layout_transform', ('TENSOR', (10, 4, 28, 28, 128), 'float32'), 'NCHW128c', 'NCHW2c')). A file containing the errors has been written to /tmp/tvm_tuning_errors_irwzqkj2.log.

The contents of /tmp/tvm_tuning_errors_irwzqkj2.log are as follows:

This is the result of comparing the predictions between Keras and the TVM model compiled with auto tuning: [image]

The script is as follows:

import keras
import os
import tvm
import tvm.relay as relay
import numpy as np
from PIL import Image
from tvm.contrib import graph_runtime
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
from tvm import autotvm

def getTestData(pic_number):
    x_return = []
    pic = Image.open("test_pic.png")
    pic = np.array(pic)
    for i in range(pic_number):
        x_return.append(pic)
    return np.array(x_return)

def compare(model_name):
    batch_size = 10  # test_pic number, you can change the batch_size here!
    x_test = getTestData(batch_size)  # get test images
    predict_model = keras.models.load_model(model_name)
    res_keras = predict_model.predict(x_test)

    # ############################ load keras model to relay IRModule  ##################################
    input_shape = (batch_size,3,224,224)
    output_shape = (batch_size, 1000)

    shape_dict = {"input_1": input_shape}
    target = 'llvm -mcpu=core-avx2'
    ctx = tvm.cpu(0)

    irmod, params = relay.frontend.from_keras(predict_model, shape_dict)

    # ########################## Compile the RelayIR ####################################################
    with autotvm.apply_graph_best("resnet50-imagenet_origin.h5_graph_opt_tuning_2.log"):
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build_module.build(
                irmod, target=target, params=params)

        module = graph_runtime.create(graph, lib, ctx)
        test_x_tvm = x_test.transpose([0, 3, 1, 2])
        dtype = 'float32'
        data = test_x_tvm.astype(dtype)
        module.set_input("input_1", data)
        module.set_input(**params)
        module.run()
        res_tvm = module.get_output(0, tvm.nd.empty(output_shape)).asnumpy()
    # ###########################  calc tvm accuracy   #################################################

    np.testing.assert_allclose(res_keras,res_tvm,rtol=1e-5,atol=1e-5)

if __name__ == '__main__':
    model_name = "resnet50-imagenet_origin.h5"
    compare(model_name)
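For completeness, the tuning script that produced resnet50-imagenet_origin.h5_graph_opt_tuning_2.log is not shown here. Judging from the imports above, it presumably follows the standard x86 AutoTVM + graph-tuner flow, roughly like the sketch below. The kernel log file name, the measure options, and the function name tune_model are my assumptions; n_trial=2 corresponds to the "2 tuning trials" mentioned in the description.

# Hedged sketch of the presumed tuning step, not the author's exact script.
# "kernel_tuning_2.log", the measure options, and tune_model are assumptions.
import tvm
import tvm.relay as relay
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner
from tvm.autotvm.graph_tuner import DPTuner

def tune_model(irmod, params, input_shape, target):
    kernel_log = "kernel_tuning_2.log"
    graph_opt_log = "resnet50-imagenet_origin.h5_graph_opt_tuning_2.log"

    # Extract conv2d tasks from the Relay module. Task extraction builds the
    # graph at opt_level=3 internally, which matches the later remark that
    # AutoTVM's opt level is 3 by default.
    tasks = autotvm.task.extract_from_program(
        irmod["main"], target=target, params=params,
        ops=(relay.op.get("nn.conv2d"),))

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=1000))

    # Kernel-level tuning with only 2 trials per task.
    for task in tasks:
        tuner = XGBTuner(task, loss_type="rank")
        tuner.tune(n_trial=2,
                   early_stopping=None,
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(kernel_log)])

    # Graph-level tuning; benchmarking layout transforms at this stage is
    # where the "Could not find any valid schedule for task
    # Task(func_name=layout_transform, ...)" warning quoted above would come from.
    executor = DPTuner(irmod["main"], {"input_1": input_shape}, kernel_log,
                       [relay.op.get("nn.conv2d")], target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(graph_opt_log)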

resnet50-imagenet_origin.h5_graph_opt_tuning_2.log can be accessed with this link:

@comaniac @FrozenGene @merrymercy Can you give me some advice? Thanks a lot!

If you set the batch size to 1, do you get a correct result?

Setting the tolerance to atol=1e-3, rtol=1e-3 is enough.
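Concretely, that would just relax the final assertion in the script above:

# Relaxed tolerances for the final comparison (replaces rtol/atol=1e-5).
np.testing.assert_allclose(res_keras, res_tvm, rtol=1e-3, atol=1e-3)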

Awesome! The prediction result turns out to be right when I set batch_size = 1.

But why? Can you explain it? Thanks.

I think it is the same issue: Accuracy drop when use batch and opt_level=3 · Issue #7563 · apache/tvm · GitHub. I am investigating it and hope to solve it next week.


Thanks. Following your hint, I changed the above script to set opt_level=2, and the script crashed. opt_level=1 also crashed with the same messages.

The crash messages are as follows.

Traceback (most recent call last):
  File "bug_simple.py", line 59, in <module>
    compare(model_name)
  File "bug_simple.py", line 41, in compare
    irmod, target=target, params=params)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/build_module.py", line 260, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/build_module.py", line 127, in build
    self._build(mod, target, target_host)
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa6) [0x7f3dc5287006]
  [bt] (7) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x89) [0x7f3dc5286d99]
  [bt] (6) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7f3dc5279907]
  [bt] (5) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x18f) [0x7f3dc528137f]
  [bt] (4) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa6) [0x7f3dc5287006]
  [bt] (3) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x89) [0x7f3dc5286d99]
  [bt] (2) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7f3dc5279907]
  [bt] (1) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x5f3) [0x7f3dc52817e3]
  [bt] (0) /workplace/software/tvm/tvm8/build/libtvm.so(+0x192a9cb) [0x7f3dc542d9cb]
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/backend/compile_engine.py", line 284, in lower_call
    best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/backend/compile_engine.py", line 206, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/op/op.py", line 91, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  [bt] (3) /workplace/software/tvm/tvm8/build/libtvm.so(TVMFuncCall+0x61) [0x7f3dc5430fd1]
  [bt] (2) /workplace/software/tvm/tvm8/build/libtvm.so(+0x17f8d3d) [0x7f3dc52fbd3d]
  [bt] (1) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)+0xb1) [0x7f3dc52fbb11]
  [bt] (0) /workplace/software/tvm/tvm8/build/libtvm.so(+0x192a9cb) [0x7f3dc542d9cb]
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/op/strategy/generic.py", line 663, in _compute_dense
    return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/topi_integration.py", line 161, in wrapper    cfg = DispatchContext.current.query(tgt, workload)
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/dispatcher.py", line 76, in query
    ret = self._query_inside(target, workload)
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/dispatcher.py", line 421, in _query_inside
    assert wkl == workload
TVMError: AssertionError

By the way, another similar question can be found at this link: Why tuning 2 times, the accuracy of compiled model nearly 0?.

I hope this information helps you locate the bug.

I guess the issue in your post is related to the fact that AutoTVM's opt level is 3 by default.

@sqchao

Would you mind helping me with one experiment? Just apply this patch to verify whether it works for you in the multi-batch case:

diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py
index 6492b78d6..0162780c1 100644
--- a/python/tvm/topi/x86/injective.py
+++ b/python/tvm/topi/x86/injective.py
@@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out):
          The updated schedule.
     """
     if len(sch[out].op.axis) >= 5:
-        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
         sch[out].parallel(fused)
     elif len(sch[out].op.axis) >= 3:
         fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])

Confirmed, it works.

Thank you very much. Multi-batch with opt_level=3 now gives the correct prediction result.

This is not an official solution, but I will try to resolve it as soon as possible.

@FrozenGene Thanks, your patch solves the inconsistent prediction bug for opt_level=3 with multi-batch input.

However, when I set opt_level=1 or 2 and compile the model under the statement with autotvm.apply_graph_best("resnet50-imagenet_origin.h5_graph_opt_tuning_2.log"):, the script crashes. The script can be found in the description above.

Detailed information:

  1. This is not related to batch_size.
  2. opt_level=1 or 2 together with the statement with autotvm.apply_graph_best(..) ----> crash (see the minimal snippet below).
  3. opt_level=3, or deleting the statement with autotvm.apply_graph_best(..) ----> no crash (the result is correct).
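For reference, this is the minimal build pattern from the script above that triggers the crash; only the opt_level differs from the working configuration:

# Same build code as in the script above, only opt_level is changed.
# opt_level=1 or 2 here crashes with the AssertionError shown below,
# while opt_level=3 (or removing apply_graph_best) builds and runs correctly.
with autotvm.apply_graph_best("resnet50-imagenet_origin.h5_graph_opt_tuning_2.log"):
    with tvm.transform.PassContext(opt_level=2):
        graph, lib, params = relay.build_module.build(
            irmod, target=target, params=params)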

The crash messages are as follows:

Traceback (most recent call last):
  File "bug_simple.py", line 59, in <module>
    compare(model_name)
  File "bug_simple.py", line 41, in compare
    irmod, target=target, params=params)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/build_module.py", line 260, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/build_module.py", line 127, in build
    self._build(mod, target, target_host)
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa6) [0x7fb72b709006]
  [bt] (7) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x89) [0x7fb72b708d99]
  [bt] (6) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7fb72b6fb907]
  [bt] (5) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x18f) [0x7fb72b70337f]
  [bt] (4) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa6) [0x7fb72b709006]
  [bt] (3) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x89) [0x7fb72b708d99]
  [bt] (2) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7fb72b6fb907]
  [bt] (1) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x5f3) [0x7fb72b7037e3]
  [bt] (0) /workplace/software/tvm/tvm8/build/libtvm.so(+0x192a9cb) [0x7fb72b8af9cb]
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/backend/compile_engine.py", line 284, in lower_call
    best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/backend/compile_engine.py", line 206, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/op/op.py", line 91, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  [bt] (3) /workplace/software/tvm/tvm8/build/libtvm.so(TVMFuncCall+0x61) [0x7fb72b8b2fd1]
  [bt] (2) /workplace/software/tvm/tvm8/build/libtvm.so(+0x17f8d3d) [0x7fb72b77dd3d]
  [bt] (1) /workplace/software/tvm/tvm8/build/libtvm.so(tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)+0xb1) [0x7fb72b77db11]
  [bt] (0) /workplace/software/tvm/tvm8/build/libtvm.so(+0x192a9cb) [0x7fb72b8af9cb]
  File "/workplace/software/tvm/tvm8/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm8/python/tvm/relay/op/strategy/generic.py", line 663, in _compute_dense
    return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/topi_integration.py", line 161, in wrapper
    cfg = DispatchContext.current.query(tgt, workload)
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/dispatcher.py", line 76, in query
    ret = self._query_inside(target, workload)
  File "/workplace/software/tvm/tvm8/python/tvm/autotvm/task/dispatcher.py", line 421, in _query_inside
    assert wkl == workload
TVMError: AssertionError

It looks like another new bug. Can you help me check it?

Is the opt_level 1/2 issue urgent for you? Setting opt_level to 3 solves your problem, so I would personally prefer to fix the multi-batch, opt_level=3 problem first (with an official solution, not the hack patch here), and then address the opt_level 1/2 crash.

@FrozenGene It’s not urgent for me.