Error - Running an ONNX model with DNNL

Hi, I’ve been trying to test TVM while using a DNNL codegen, and I’ve been running into some errors when trying to do so with an ONNX model.

../tests/python/relay/test_dnnl_onnx.py:35: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the  new recommended usage.
  graph, lib, param = relay.build(mod, target=target, params=params, runtime=runtime)
Traceback (most recent call last):
  File "../tests/python/relay/test_dnnl_onnx.py", line 86, in <module>
    test_extern_onnx_model(onnx_model_path, codegen_str="dnnl", ishape= onnx_model_ishape, oshape = onnx_model_oshape, input_name = onnx_model_input_name)
  File "../tests/python/relay/test_dnnl_onnx.py", line 67, in test_extern_onnx_model
    check_result_no_compare(mod, {input_name: i_data}, oshape, params=params)
  File "../tests/python/relay/test_dnnl_onnx.py", line 48, in check_result_no_compare
    check_graph_executor_result()
  File "../tests/python/relay/test_dnnl_onnx.py", line 37, in check_graph_executor_result
    rt_mod = tvm.contrib.graph_executor.create(graph, lib, device)
  File "/projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/python/tvm/contrib/graph_executor.py", line 66, in create
    return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
  File "/projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::json::JSONRuntimeBase::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#4}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x38) [0x7f479360d7b6]
  [bt] (7) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::json::JSONRuntimeBase::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#4}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x187) [0x7f47935eb47d]
  [bt] (6) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::contrib::DNNLJSONRuntime::Init(tvm::runtime::Array<tvm::runtime::NDArray, void> const&)+0x30) [0x7f47935f0474]
  [bt] (5) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::contrib::DNNLJSONRuntime::BuildEngine()+0x5c7) [0x7f47935f18f1]
  [bt] (4) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::contrib::DNNLJSONRuntime::Binary(unsigned long const&, dnnl::algorithm)+0x335) [0x7f47935faacb]
  [bt] (3) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::contrib::DNNLJSONRuntime::GenDNNLMemDescByShape(std::vector<long, std::allocator<long> > const&, dnnl::memory::data_type)+0x21f) [0x7f47935fb45f]
  [bt] (2) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::detail::LogFatal::~LogFatal()+0x36) [0x7f47916bdbd6]
  [bt] (1) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x4a) [0x7f47916bdd4a]
  [bt] (0) /projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/build/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x35) [0x7f47934c3e05]
  File "/projects/vbu_projects/users/guyz/DockerTVM/vbu_tvm_dsp_backend/src/runtime/contrib/dnnl/dnnl_json_runtime.cc", line 843
TVMError: Unsupported data shape dimension: 0

The error can be reproduced when using the following code:

import os
import sys
import numpy as np
import tvm
from tvm.relay.backend import te_compiler
from tvm.relay.backend.runtime import Runtime
import tvm.relay.testing
from tvm import relay
from tvm import runtime as tvm_runtime
from tvm.relay import transform
from tvm.contrib import utils
import onnx
from tvm.contrib.download import download_testdata


def check_result_no_compare(mod, map_inputs, out_shape, target="llvm", device=tvm.cpu(), params=None, runtime=Runtime("cpp")):
    def update_lib(lib):
        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
        source_dir = os.path.join(test_dir, "..", "..", "..")
        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
        kwargs = {}
        kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
        tmp_path = utils.tempdir()
        lib_name = "lib.so"
        lib_path = tmp_path.relpath(lib_name)
        lib.export_library(lib_path, fcompile=False, **kwargs)
        lib = tvm_runtime.load_module(lib_path)

        return lib

    
    def check_graph_executor_result():
        te_compiler.get().clear()
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, param = relay.build(mod, target=target, params=params, runtime=runtime)
        lib = update_lib(lib) 
        rt_mod = tvm.contrib.graph_executor.create(graph, lib, device)
        for name, data in map_inputs.items():
            rt_mod.set_input(name, data)
        rt_mod.set_input(**param)
        rt_mod.run()
        out_shapes = out_shape if isinstance(out_shape, list) else [out_shape]

        for idx, shape in enumerate(out_shapes):
            out = tvm.nd.empty(shape, device=device)
            out = rt_mod.get_output(idx, out)

    check_graph_executor_result()


def test_extern_onnx_model(path, codegen_str="cdnn", ishape=(1, 3, 224, 224), oshape = (1, 32, 224, 224), input_name = "input"):
    if not tvm.get_global_func("relay.ext.dnnl", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    
    block = onnx.load(path)
    shape_dict = {input_name: ishape}
    mod, params = relay.frontend.from_onnx(block, shape_dict)

    mod = transform.AnnotateTarget([codegen_str])(mod) # Comment this line for this code to run
    mod = transform.MergeCompilerRegions()(mod) # Comment this line for this code to run
    mod = transform.PartitionGraph()(mod) # Comment this line for this code to run
    i_data = np.random.uniform(0, 1, ishape).astype(dtype)

    check_result_no_compare(mod, {input_name: i_data}, oshape, params=params)



if __name__ == "__main__":
    codegen = "dnnl"
    model = "resnet50-v2-7"
    model_url = (
    "https://github.com/onnx/models/raw/main/"
    "vision/classification/resnet/model/"
    "resnet50-v2-7.onnx"
    )
    onnx_model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx")
    onnx_model_ishape = (1, 3, 224, 224)
    onnx_model_oshape = (1, 1000)
    onnx_model_input_name = "data"

     
    print("Testing " + model + " (" + onnx_model_path + ") with " + codegen, flush=True)
    test_extern_onnx_model(onnx_model_path, codegen_str="dnnl", ishape= onnx_model_ishape, oshape = onnx_model_oshape, input_name = onnx_model_input_name)

Note I’ve been able to successfully run an MXNet model while offloading it to DNNL, and also note this same model will run if you don’t offload it to DNNL (which you can do if you comment the three lines I marked in test_extern_onnx_model), so I don’t think it’s a problem with this ONNX model or with using the DNNL codegen, but rather the combination of the two.

Help will be greatly appriciated.