Failed to process ONNX Where op on Hexagon

I tried to deploy an ONNX model to Hexagon and encountered the error below.

Check failed: (IsPointerType(buffer_var->type_annotation, dtype)) is false: The allocated data type (bool) does not match the type annotation of the buffer fused_constant (T.handle("int8")). The data type should be an element of the pointer type.

I located the op causing the issue, which is the Where op, so I made a small model (where.onnx) that reproduces it. The code is below.

import numpy as np
import pytest

import tvm.testing
from tvm import relay
from tvm.contrib.hexagon.session import Session
from tvm.relay.backend import Executor, Runtime


def get_model():
    """Load the reproducer ONNX model, skipping the test when onnx is absent."""
    onnx = pytest.importorskip("onnx")
    return onnx.load("where.onnx")

@tvm.testing.requires_hexagon
def test_where_aot(hexagon_session: Session):
    """Reproduce the Where-op lowering failure with the AOT executor on Hexagon."""
    dtype = "float32"
    model = get_model()

    input_name = "Y"
    input_data = np.random.rand(128).astype(dtype=dtype)

    # Import the ONNX graph into Relay, folding the model weights into
    # the graph as constants.
    relay_mod, params = relay.frontend.from_onnx(
        model, {input_name: input_data.shape}, freeze_params=True
    )

    # Build for Hexagon v68 with the AOT executor; per the log below, this
    # is where the bool/int8 fused_constant annotation check fails.
    target = tvm.target.Target(tvm.target.hexagon("v68"), host="c")
    with tvm.transform.PassContext(opt_level=3):
        hexagon_lowered = tvm.relay.build(
            relay_mod,
            target,
            runtime=Runtime("cpp"),
            executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
            params=params,
        )

    # Run the compiled module on the Hexagon session and dump the output.
    hexagon_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
    hexagon_mod.set_input(**{input_name: input_data})
    hexagon_mod.run()
    print(hexagon_mod.get_output(0).numpy())


if __name__ == "__main__":
    # Delegate to TVM's pytest-based test entry point when run as a script.
    tvm.testing.main()

The error log is below.

~/tests/python/contrib/test_hexagon# pytest test_where.py -v -s
[08:04:13] /home/user/workspace/codes/tvm/src/target/target_kind.cc:164: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
enabled targets: llvm; hexagon
pytest marker: 
================================================================================= test session starts =================================================================================
platform linux -- Python 3.8.10, pytest-7.2.1, pluggy-1.0.0 -- /venv/apache-tvm-py3.8/bin/python3
cachedir: .pytest_cache
rootdir: /home/user/workspace/codes/tvm/tests/python/contrib/test_hexagon
plugins: xdist-3.1.0, rerunfailures-10.2, lazy-fixture-0.6.3, profiling-1.7.0
collected 1 item                                                                                                                                                                      

test_where.py::test_where_aot FAILED

====================================================================================== FAILURES =======================================================================================
___________________________________________________________________________________ test_where_aot ____________________________________________________________________________________

hexagon_session = <tvm.contrib.hexagon.session.Session object at 0x7fea07033100>

    @tvm.testing.requires_hexagon
    def test_where_aot(hexagon_session: Session):
        """Test mobilenet with aot executor"""
        dtype = "float32"
        onnx_model = get_model()
    
        data_in = np.random.rand(128).astype(dtype=dtype)
        input_name = "Y"
        shape_dict = {input_name: data_in.shape}
        relay_mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
        inputs = {input_name: data_in}
    
        with tvm.transform.PassContext(opt_level=3):
>           hexagon_lowered = tvm.relay.build(
                relay_mod,
                tvm.target.Target(tvm.target.hexagon("v68"), host="c"),
                runtime=Runtime("cpp"),
                executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
                params=params,
            )

test_where.py:48: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../../python/tvm/relay/build_module.py:364: in build
    graph_json, runtime_mod, params = bld_mod.build(
../../../../python/tvm/relay/build_module.py:161: in build
    self._build(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tvm.runtime.packed_func.PackedFunc object at 0x7fea3117e280>
args = (def @main(%Y: Tensor[(128), float32] /* ty=Tensor[(128), float32] span=Where_0.Y:0:0 */) -> Tensor[(128), float32] {
...y=0], None, aot{"interface-api": "packed", "link-params": T.bool(True), "unpacked-api": T.bool(False)}, cpp, None, ...)
temp_args = [], values = <tvm._ffi._ctypes.packed_func.TVMValue_Array_8 object at 0x7fea311e9540>, tcodes = <tvm._ffi._ctypes.packed_func.c_int_Array_8 object at 0x7fea311e9ac0>

    def __call__(self, *args):
        """Call the function with positional arguments
    
        args : list
           The positional arguments to the function call.
        """
        temp_args = []
        values, tcodes, num_args = _make_tvm_args(args, temp_args)
        ret_val = TVMValue()
        ret_tcode = ctypes.c_int()
        if (
            _LIB.TVMFuncCall(
                self.handle,
                values,
                tcodes,
                ctypes.c_int(num_args),
                ctypes.byref(ret_val),
                ctypes.byref(ret_tcode),
            )
            != 0
        ):
>           raise get_last_ffi_error()
E           tvm._ffi.base.TVMError: Traceback (most recent call last):
E             41: TVMFuncCall
E             40: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
E             39: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
E             38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
E             37: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
E             36: tvm::transform::Pass::operator()(tvm::IRModule) const
E             35: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             34: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             33: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             32: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             31: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
E             30: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
E             29: tvm::transform::Pass::operator()(tvm::IRModule) const
E             28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             27: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             26: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SM_SQ_
E             25: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
E             24: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
E             23: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
E             22: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
E             21: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
E             20: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
E             19: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
E             18: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
E             17: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
E             16: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
E             15: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
E             14: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
E             13: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
E             12: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
E             11: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
E             10: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
E             9: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
E             8: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&)
E             7: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
E             6: tvm::transform::Pass::operator()(tvm::IRModule) const
E             5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             4: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E             3: _ZN3tvm7runtime13PackedFuncObj
E             2: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::BindParams(tvm::runtime::Array<tvm::runtime::NDArray, void> const&)::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::BindParams(tvm::runtime::Array<tvm::runtime::NDArray, void> const&)::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
E             1: tvm::tir::BindParams(tvm::tir::PrimFunc, tvm::runtime::Array<tvm::runtime::NDArray, void> const&)
E             0: tvm::tir::AllocateConst::AllocateConst(tvm::tir::Var, tvm::runtime::DataType, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::ObjectRef, tvm::tir::Stmt, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::ObjectRef, void, void>, tvm::Span)
E             File "/home/user/workspace/codes/tvm/src/tir/ir/stmt.cc", line 275
E           TVMError: 
E           ---------------------------------------------------------------
E           An error occurred during the execution of TVM.
E           For more information, please see: https://tvm.apache.org/docs/errors.html
E           ---------------------------------------------------------------
E             Check failed: (IsPointerType(buffer_var->type_annotation, dtype)) is false: The allocated data type (bool) does not match the type annotation of the buffer fused_constant (T.handle("int8")). The data type should be an element of the pointer type.

../../../../python/tvm/_ffi/_ctypes/packed_func.py:237: TVMError
=============================================================================== short test summary info ===============================================================================
FAILED test_where.py::test_where_aot - tvm._ffi.base.TVMError: Traceback (most recent call last):
================================================================================= 1 failed in 26.49s ==================================================================================

Hi @libin, were you able to resolve this issue?

I didn’t resolve the issue; I worked around it by removing the Where op from the model.